Add multiple features
This commit is contained in:
335
crates/wtfnet-tls/src/lib.rs
Normal file
335
crates/wtfnet-tls/src/lib.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
use rustls::{Certificate, ClientConfig, RootCertStore, ServerName};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use x509_parser::prelude::{FromDer, X509Certificate};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TlsError {
|
||||
#[error("invalid target: {0}")]
|
||||
InvalidTarget(String),
|
||||
#[error("invalid sni: {0}")]
|
||||
InvalidSni(String),
|
||||
#[error("io error: {0}")]
|
||||
Io(String),
|
||||
#[error("tls error: {0}")]
|
||||
Tls(String),
|
||||
#[error("parse error: {0}")]
|
||||
Parse(String),
|
||||
#[error("timeout")]
|
||||
Timeout,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsCertSummary {
|
||||
pub subject: String,
|
||||
pub issuer: String,
|
||||
pub not_before: String,
|
||||
pub not_after: String,
|
||||
pub san: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsHandshakeReport {
|
||||
pub target: String,
|
||||
pub sni: Option<String>,
|
||||
pub alpn_offered: Vec<String>,
|
||||
pub alpn_negotiated: Option<String>,
|
||||
pub tls_version: Option<String>,
|
||||
pub cipher: Option<String>,
|
||||
pub cert_chain: Vec<TlsCertSummary>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsVerifyReport {
|
||||
pub target: String,
|
||||
pub sni: Option<String>,
|
||||
pub alpn_offered: Vec<String>,
|
||||
pub alpn_negotiated: Option<String>,
|
||||
pub tls_version: Option<String>,
|
||||
pub cipher: Option<String>,
|
||||
pub verified: bool,
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsCertReport {
|
||||
pub target: String,
|
||||
pub sni: Option<String>,
|
||||
pub cert_chain: Vec<TlsCertSummary>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsAlpnReport {
|
||||
pub target: String,
|
||||
pub sni: Option<String>,
|
||||
pub alpn_offered: Vec<String>,
|
||||
pub alpn_negotiated: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TlsOptions {
|
||||
pub sni: Option<String>,
|
||||
pub alpn: Vec<String>,
|
||||
pub timeout_ms: u64,
|
||||
pub insecure: bool,
|
||||
}
|
||||
|
||||
pub async fn handshake(target: &str, options: TlsOptions) -> Result<TlsHandshakeReport, TlsError> {
|
||||
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
|
||||
let connector = build_connector(options.insecure, &options.alpn)?;
|
||||
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
|
||||
let (_, session) = stream.get_ref();
|
||||
|
||||
Ok(TlsHandshakeReport {
|
||||
target: target.to_string(),
|
||||
sni: options.sni,
|
||||
alpn_offered: options.alpn.clone(),
|
||||
alpn_negotiated: session
|
||||
.alpn_protocol()
|
||||
.map(|value| String::from_utf8_lossy(value).to_string()),
|
||||
tls_version: session.protocol_version().map(|v| format!("{v:?}")),
|
||||
cipher: session
|
||||
.negotiated_cipher_suite()
|
||||
.map(|suite| format!("{suite:?}")),
|
||||
cert_chain: extract_cert_chain(session.peer_certificates())?,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn verify(target: &str, options: TlsOptions) -> Result<TlsVerifyReport, TlsError> {
|
||||
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
|
||||
let connector = build_connector(false, &options.alpn)?;
|
||||
match connect(addr, connector, server_name, options.timeout_ms).await {
|
||||
Ok(stream) => {
|
||||
let (_, session) = stream.get_ref();
|
||||
Ok(TlsVerifyReport {
|
||||
target: target.to_string(),
|
||||
sni: options.sni,
|
||||
alpn_offered: options.alpn.clone(),
|
||||
alpn_negotiated: session
|
||||
.alpn_protocol()
|
||||
.map(|value| String::from_utf8_lossy(value).to_string()),
|
||||
tls_version: session.protocol_version().map(|v| format!("{v:?}")),
|
||||
cipher: session
|
||||
.negotiated_cipher_suite()
|
||||
.map(|suite| format!("{suite:?}")),
|
||||
verified: true,
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
Err(err) => Ok(TlsVerifyReport {
|
||||
target: target.to_string(),
|
||||
sni: options.sni,
|
||||
alpn_offered: options.alpn.clone(),
|
||||
alpn_negotiated: None,
|
||||
tls_version: None,
|
||||
cipher: None,
|
||||
verified: false,
|
||||
error: Some(err.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn certs(target: &str, options: TlsOptions) -> Result<TlsCertReport, TlsError> {
|
||||
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
|
||||
let connector = build_connector(options.insecure, &options.alpn)?;
|
||||
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
|
||||
let (_, session) = stream.get_ref();
|
||||
Ok(TlsCertReport {
|
||||
target: target.to_string(),
|
||||
sni: options.sni,
|
||||
cert_chain: extract_cert_chain(session.peer_certificates())?,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn alpn(target: &str, options: TlsOptions) -> Result<TlsAlpnReport, TlsError> {
|
||||
let (addr, server_name) = parse_target(target, options.sni.as_deref())?;
|
||||
let connector = build_connector(options.insecure, &options.alpn)?;
|
||||
let stream = connect(addr, connector, server_name, options.timeout_ms).await?;
|
||||
let (_, session) = stream.get_ref();
|
||||
Ok(TlsAlpnReport {
|
||||
target: target.to_string(),
|
||||
sni: options.sni,
|
||||
alpn_offered: options.alpn.clone(),
|
||||
alpn_negotiated: session
|
||||
.alpn_protocol()
|
||||
.map(|value| String::from_utf8_lossy(value).to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_target(target: &str, sni: Option<&str>) -> Result<(SocketAddr, ServerName), TlsError> {
|
||||
let (host, port) = split_host_port(target)?;
|
||||
let addr = resolve_addr(&host, port)?;
|
||||
let server_name = if let Some(sni) = sni {
|
||||
ServerName::try_from(sni).map_err(|_| TlsError::InvalidSni(sni.to_string()))?
|
||||
} else if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
ServerName::IpAddress(ip)
|
||||
} else {
|
||||
ServerName::try_from(host.as_str())
|
||||
.map_err(|_| TlsError::InvalidSni(host.to_string()))?
|
||||
};
|
||||
Ok((addr, server_name))
|
||||
}
|
||||
|
||||
fn split_host_port(value: &str) -> Result<(String, u16), TlsError> {
|
||||
if let Some(stripped) = value.strip_prefix('[') {
|
||||
if let Some(end) = stripped.find(']') {
|
||||
let host = &stripped[..end];
|
||||
let rest = &stripped[end + 1..];
|
||||
let port = rest
|
||||
.strip_prefix(':')
|
||||
.ok_or_else(|| TlsError::InvalidTarget(value.to_string()))?;
|
||||
let port = port
|
||||
.parse::<u16>()
|
||||
.map_err(|_| TlsError::InvalidTarget(value.to_string()))?;
|
||||
return Ok((host.to_string(), port));
|
||||
}
|
||||
}
|
||||
|
||||
let mut parts = value.rsplitn(2, ':');
|
||||
let port = parts
|
||||
.next()
|
||||
.ok_or_else(|| TlsError::InvalidTarget(value.to_string()))?;
|
||||
let host = parts
|
||||
.next()
|
||||
.ok_or_else(|| TlsError::InvalidTarget(value.to_string()))?;
|
||||
if host.contains(':') {
|
||||
return Err(TlsError::InvalidTarget(value.to_string()));
|
||||
}
|
||||
let port = port
|
||||
.parse::<u16>()
|
||||
.map_err(|_| TlsError::InvalidTarget(value.to_string()))?;
|
||||
Ok((host.to_string(), port))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn split_host_port_ipv4() {
|
||||
let (host, port) = split_host_port("example.com:443").unwrap();
|
||||
assert_eq!(host, "example.com");
|
||||
assert_eq!(port, 443);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn split_host_port_ipv6() {
|
||||
let (host, port) = split_host_port("[2001:db8::1]:443").unwrap();
|
||||
assert_eq!(host, "2001:db8::1");
|
||||
assert_eq!(port, 443);
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_addr(host: &str, port: u16) -> Result<SocketAddr, TlsError> {
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
return Ok(SocketAddr::new(ip, port));
|
||||
}
|
||||
let addr = std::net::ToSocketAddrs::to_socket_addrs(&(host, port))
|
||||
.map_err(|err| TlsError::Io(err.to_string()))?
|
||||
.next()
|
||||
.ok_or_else(|| TlsError::InvalidTarget(host.to_string()))?;
|
||||
Ok(addr)
|
||||
}
|
||||
|
||||
fn build_connector(insecure: bool, alpn: &[String]) -> Result<TlsConnector, TlsError> {
|
||||
let mut config = if insecure {
|
||||
ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_custom_certificate_verifier(Arc::new(NoVerifier))
|
||||
.with_no_client_auth()
|
||||
} else {
|
||||
let mut roots = RootCertStore::empty();
|
||||
let store = rustls_native_certs::load_native_certs()
|
||||
.map_err(|err| TlsError::Io(err.to_string()))?;
|
||||
for cert in store {
|
||||
roots
|
||||
.add(&Certificate(cert.0))
|
||||
.map_err(|err| TlsError::Tls(err.to_string()))?;
|
||||
}
|
||||
ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth()
|
||||
};
|
||||
|
||||
if !alpn.is_empty() {
|
||||
config.alpn_protocols = alpn.iter().map(|p| p.as_bytes().to_vec()).collect();
|
||||
}
|
||||
|
||||
Ok(TlsConnector::from(Arc::new(config)))
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
addr: SocketAddr,
|
||||
connector: TlsConnector,
|
||||
server_name: ServerName,
|
||||
timeout_ms: u64,
|
||||
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, TlsError> {
|
||||
let tcp = timeout(Duration::from_millis(timeout_ms), TcpStream::connect(addr))
|
||||
.await
|
||||
.map_err(|_| TlsError::Timeout)?
|
||||
.map_err(|err| TlsError::Io(err.to_string()))?;
|
||||
let stream = timeout(
|
||||
Duration::from_millis(timeout_ms),
|
||||
connector.connect(server_name, tcp),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| TlsError::Timeout)?
|
||||
.map_err(|err| TlsError::Tls(err.to_string()))?;
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
fn extract_cert_chain(certs: Option<&[Certificate]>) -> Result<Vec<TlsCertSummary>, TlsError> {
|
||||
let mut results = Vec::new();
|
||||
if let Some(certs) = certs {
|
||||
for cert in certs {
|
||||
let summary = parse_cert(&cert.0)?;
|
||||
results.push(summary);
|
||||
}
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
fn parse_cert(der: &[u8]) -> Result<TlsCertSummary, TlsError> {
|
||||
let (_, cert) =
|
||||
X509Certificate::from_der(der).map_err(|err| TlsError::Parse(err.to_string()))?;
|
||||
Ok(TlsCertSummary {
|
||||
subject: cert.subject().to_string(),
|
||||
issuer: cert.issuer().to_string(),
|
||||
not_before: cert.validity().not_before.to_string(),
|
||||
not_after: cert.validity().not_after.to_string(),
|
||||
san: extract_san(&cert),
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_san(cert: &X509Certificate<'_>) -> Vec<String> {
|
||||
let mut result = Vec::new();
|
||||
if let Ok(Some(ext)) = cert.subject_alternative_name() {
|
||||
for name in ext.value.general_names.iter() {
|
||||
result.push(name.to_string());
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
struct NoVerifier;
|
||||
|
||||
impl rustls::client::ServerCertVerifier for NoVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &Certificate,
|
||||
_intermediates: &[Certificate],
|
||||
_server_name: &ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp: &[u8],
|
||||
_now: SystemTime,
|
||||
) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
|
||||
Ok(rustls::client::ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user