Files
WTFnet/crates/wtfnet-calc/src/lib.rs
2026-01-16 23:16:58 +08:00

232 lines
6.8 KiB
Rust

use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use serde::{Deserialize, Serialize};
use std::net::{IpAddr, Ipv4Addr};
use thiserror::Error;
/// Errors returned by the network calculator functions in this file.
#[derive(Debug, Error)]
pub enum CalcError {
/// The input was rejected for its shape rather than its address syntax,
/// for example an empty string, too many whitespace-separated tokens, or an
/// empty list of CIDRs. The payload is a human-readable reason.
#[error("invalid input: {0}")]
InvalidInput(String),
/// An address, mask, or CIDR string failed to parse. The payload is a
/// message describing what failed.
#[error("parse error: {0}")]
Parse(String),
}
/// Summary of a parsed network, as built by `subnet_info`.
///
/// Every address and count is carried as a string. The counts are decimal
/// strings; the IPv6 `/0` total is produced from a literal in
/// `total_addresses_v6` rather than computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubnetInfo {
/// The caller's input, copied verbatim.
pub input: String,
/// Address family tag: `"ipv4"` or `"ipv6"`.
pub version: String,
/// String form of the parsed network. Host bits in the input are kept
/// (e.g. `192.168.1.10/24`).
pub cidr: String,
/// Network (base) address.
pub network: String,
/// Broadcast address; `Some` for IPv4, always `None` for IPv6.
pub broadcast: Option<String>,
/// Netmask in address notation.
pub netmask: String,
/// Hostmask in address notation.
pub hostmask: String,
/// Prefix length in bits.
pub prefix_len: u8,
/// Number of addresses in the network, as a decimal string.
pub total_addresses: String,
/// IPv4: `total_addresses` minus 2 for prefixes up to /30, otherwise the
/// full total. IPv6: identical to `total_addresses`.
pub usable_addresses: String,
/// IPv4: network address + 1 for prefixes up to /30, otherwise the network
/// address itself. IPv6: the network address.
pub first_host: Option<String>,
/// IPv4: broadcast address - 1 for prefixes up to /30, otherwise the
/// broadcast address itself. IPv6: the last address of the network.
pub last_host: Option<String>,
}
/// Parses `input` as a network and reports its derived properties.
///
/// # Errors
///
/// Propagates any error returned by `parse_net` for `input`.
pub fn subnet_info(input: &str) -> Result<SubnetInfo, CalcError> {
    // Dispatch on the address family of the parsed network.
    let info = match parse_net(input)? {
        IpNet::V4(net) => subnet_info_v4(input, net),
        IpNet::V6(net) => subnet_info_v6(input, net),
    };
    Ok(info)
}
/// Reports whether the network described by `a` contains the network
/// described by `b`.
///
/// # Errors
///
/// Propagates the parse error for `a` first, then for `b`.
pub fn contains(a: &str, b: &str) -> Result<bool, CalcError> {
    // Tuple fields are evaluated left to right, so `a` is still parsed first.
    let (outer, inner) = (parse_net(a)?, parse_net(b)?);
    Ok(outer.contains(&inner))
}
/// Reports whether the networks described by `a` and `b` share any address.
///
/// Networks of different address families are reported as non-overlapping.
///
/// # Errors
///
/// Propagates the parse error for `a` first, then for `b`.
pub fn overlap(a: &str, b: &str) -> Result<bool, CalcError> {
    let left = parse_net(a)?;
    let right = parse_net(b)?;
    let overlapping = match (left, right) {
        (IpNet::V4(x), IpNet::V4(y)) => overlap_v4(x, y),
        (IpNet::V6(x), IpNet::V6(y)) => overlap_v6(x, y),
        // Mixed families: nothing to compare.
        (IpNet::V4(_), IpNet::V6(_)) | (IpNet::V6(_), IpNet::V4(_)) => false,
    };
    Ok(overlapping)
}
/// Parses every entry of `inputs` and aggregates the resulting networks with
/// `IpNet::aggregate`.
///
/// # Errors
///
/// Returns `CalcError::InvalidInput` when `inputs` is empty, or the error of
/// the first entry that fails to parse.
pub fn summarize(inputs: &[String]) -> Result<Vec<IpNet>, CalcError> {
    if inputs.is_empty() {
        let reason = "at least one CIDR required".to_string();
        return Err(CalcError::InvalidInput(reason));
    }

    // Collecting into `Result` stops at the first entry that fails to parse.
    let parsed = inputs
        .iter()
        .map(|cidr| parse_net(cidr.as_str()))
        .collect::<Result<Vec<IpNet>, CalcError>>()?;

    Ok(IpNet::aggregate(&parsed))
}
/// Builds the `SubnetInfo` for an IPv4 network. `input` is stored verbatim.
fn subnet_info_v4(input: &str, net: Ipv4Net) -> SubnetInfo {
    let prefix_len = net.prefix_len();
    let (first_host, last_host) = first_last_v4(net);

    SubnetInfo {
        input: input.to_owned(),
        version: String::from("ipv4"),
        cidr: net.to_string(),
        network: net.network().to_string(),
        broadcast: Some(net.broadcast().to_string()),
        netmask: net.netmask().to_string(),
        hostmask: net.hostmask().to_string(),
        prefix_len,
        total_addresses: total_addresses_v4(prefix_len),
        usable_addresses: usable_addresses_v4(prefix_len),
        first_host,
        last_host,
    }
}
/// Builds the `SubnetInfo` for an IPv6 network. `input` is stored verbatim.
///
/// No broadcast address is reported, and the usable count equals the total.
fn subnet_info_v6(input: &str, net: Ipv6Net) -> SubnetInfo {
    let prefix_len = net.prefix_len();
    let address_count = total_addresses_v6(prefix_len);
    let (first_host, last_host) = first_last_v6(net);

    SubnetInfo {
        input: input.to_owned(),
        version: String::from("ipv6"),
        cidr: net.to_string(),
        network: net.network().to_string(),
        broadcast: None,
        netmask: net.netmask().to_string(),
        hostmask: net.hostmask().to_string(),
        prefix_len,
        // Same figure for both counts.
        usable_addresses: address_count.clone(),
        total_addresses: address_count,
        first_host,
        last_host,
    }
}
/// Parses a network from one of the accepted textual forms.
///
/// After trimming surrounding whitespace, the input is split on whitespace:
/// - two tokens are treated as `<ip> <mask>`;
/// - one token of the form `<ip>/<mask>`, where the part after the first `/`
///   contains `.` or `:`, is treated as an address plus netmask;
/// - any other single token is handed to `IpNet`'s `FromStr` implementation.
///
/// # Errors
///
/// Returns `CalcError::InvalidInput` for empty input or more than two tokens,
/// and `CalcError::Parse` when the address, mask, or CIDR cannot be parsed.
fn parse_net(value: &str) -> Result<IpNet, CalcError> {
    let text = value.trim();
    let mut tokens = text.split_whitespace();

    match (tokens.next(), tokens.next(), tokens.next()) {
        (None, ..) => Err(CalcError::InvalidInput("empty input".to_string())),
        (Some(_), Some(_), Some(_)) => Err(CalcError::InvalidInput(
            "expected: <ip> <mask>".to_string(),
        )),
        (Some(ip), Some(mask), None) => parse_ip_mask(ip, mask),
        // Single token: `text` has no inner whitespace, so it is that token.
        (Some(_), None, _) => match text.split_once('/') {
            Some((ip, mask)) if mask.contains(|c: char| matches!(c, '.' | ':')) => {
                parse_ip_mask(ip, mask)
            }
            _ => text
                .parse::<IpNet>()
                .map_err(|err| CalcError::Parse(err.to_string())),
        },
    }
}
/// Builds a network from an address string and a netmask given in address
/// notation (for example `192.168.1.10` and `255.255.255.0`).
///
/// # Errors
///
/// - `CalcError::Parse` when `ip` or `mask` is not a valid IP address, or
///   when `IpNet::with_netmask` rejects the mask.
/// - `CalcError::InvalidInput` when `ip` and `mask` belong to different
///   address families.
fn parse_ip_mask(ip: &str, mask: &str) -> Result<IpNet, CalcError> {
    let addr: IpAddr = ip
        .parse()
        .map_err(|_| CalcError::Parse(format!("invalid ip: {ip}")))?;
    let netmask: IpAddr = mask
        .parse()
        .map_err(|_| CalcError::Parse(format!("invalid mask: {mask}")))?;

    // Checked explicitly so that a mask of the other family can never be
    // reinterpreted as a bare prefix length for `addr` (e.g. an IPv4 address
    // paired with `ffff::`).
    if addr.is_ipv4() != netmask.is_ipv4() {
        return Err(CalcError::InvalidInput(format!(
            "ip and mask must be the same address family: {ip} {mask}"
        )));
    }

    IpNet::with_netmask(addr, netmask).map_err(|err| CalcError::Parse(err.to_string()))
}
/// Returns the number of addresses in an IPv4 network with the given prefix
/// length, as a decimal string.
///
/// Prefixes above 32 are treated as having zero host bits.
fn total_addresses_v4(prefix: u8) -> String {
    let host_bits = 32u32.saturating_sub(u32::from(prefix));
    // host_bits is at most 32, so the power always fits in a u128.
    2u128.pow(host_bits).to_string()
}
/// Returns the usable address count of an IPv4 network with the given prefix
/// length, as a decimal string.
///
/// Prefixes up to /30 give up two addresses from the total; longer prefixes
/// keep the full total.
fn usable_addresses_v4(prefix: u8) -> String {
    let host_bits = 32u32.saturating_sub(u32::from(prefix));
    let total: u128 = 1 << host_bits;
    let usable = match prefix {
        0..=30 => total.saturating_sub(2),
        _ => total,
    };
    usable.to_string()
}
/// Returns the number of addresses in an IPv6 network with the given prefix
/// length, as a decimal string.
///
/// Prefixes above 128 are treated as having zero host bits.
fn total_addresses_v6(prefix: u8) -> String {
    match 128u32.saturating_sub(u32::from(prefix)) {
        // 2^128 does not fit in a u128, so the /0 count is written out.
        128 => String::from("340282366920938463463374607431768211456"),
        host_bits => (1u128 << host_bits).to_string(),
    }
}
/// Returns the first and last host addresses of an IPv4 network as strings.
///
/// For prefixes up to /30 these are the network address + 1 and the broadcast
/// address - 1; for longer prefixes the network and broadcast addresses are
/// returned unchanged. Both elements are always `Some`.
fn first_last_v4(net: Ipv4Net) -> (Option<String>, Option<String>) {
    let (low, high) = (net.network(), net.broadcast());
    let (low, high) = if net.prefix_len() > 30 {
        (low, high)
    } else {
        // Step inward from both ends of the range.
        (
            Ipv4Addr::from(u32::from(low).saturating_add(1)),
            Ipv4Addr::from(u32::from(high).saturating_sub(1)),
        )
    };
    (Some(low.to_string()), Some(high.to_string()))
}
/// Returns the first and last addresses of an IPv6 network as strings, taken
/// from `network()` and `broadcast()`. Both elements are always `Some`.
fn first_last_v6(net: Ipv6Net) -> (Option<String>, Option<String>) {
    let first = net.network().to_string();
    let last = net.broadcast().to_string();
    (Some(first), Some(last))
}
/// Reports whether two IPv4 networks share at least one address, comparing
/// their inclusive `[network, broadcast]` ranges as integers.
fn overlap_v4(a: Ipv4Net, b: Ipv4Net) -> bool {
    let bounds = |net: Ipv4Net| (u32::from(net.network()), u32::from(net.broadcast()));
    let (a_lo, a_hi) = bounds(a);
    let (b_lo, b_hi) = bounds(b);
    // The ranges are disjoint only when one starts after the other ends.
    !(a_lo > b_hi || b_lo > a_hi)
}
/// Reports whether two IPv6 networks share at least one address, comparing
/// their inclusive `[network, broadcast]` ranges as integers.
fn overlap_v6(a: Ipv6Net, b: Ipv6Net) -> bool {
    let bounds = |net: Ipv6Net| (u128::from(net.network()), u128::from(net.broadcast()));
    let (a_lo, a_hi) = bounds(a);
    let (b_lo, b_hi) = bounds(b);
    // The ranges are disjoint only when one starts after the other ends.
    !(a_lo > b_hi || b_lo > a_hi)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An address plus dotted netmask is accepted, and the host address is
    /// kept in the reported CIDR.
    #[test]
    fn subnet_v4_from_mask() {
        let info = subnet_info("192.168.1.10 255.255.255.0").expect("subnet");

        assert_eq!(info.cidr.as_str(), "192.168.1.10/24");
        assert_eq!(info.network.as_str(), "192.168.1.0");
        assert_eq!(info.broadcast, Some("192.168.1.255".to_string()));
        assert_eq!(info.usable_addresses.as_str(), "254");
    }

    /// Containment and overlap on simple IPv4 prefixes.
    #[test]
    fn contains_and_overlap() {
        assert!(matches!(contains("192.168.0.0/16", "192.168.1.0/24"), Ok(true)));
        assert!(matches!(overlap("10.0.0.0/24", "10.0.0.128/25"), Ok(true)));
        assert!(matches!(overlap("10.0.0.0/24", "10.0.1.0/24"), Ok(false)));
    }

    /// Two adjacent /24 networks collapse into a single /23.
    #[test]
    fn summarize_ipv4() {
        let inputs = ["10.0.0.0/24", "10.0.1.0/24"].map(String::from);
        let merged = summarize(&inputs).expect("summarize");

        let rendered: Vec<String> = merged.iter().map(ToString::to_string).collect();
        assert_eq!(rendered, ["10.0.0.0/23"]);
    }
}