// Recover a lost doc test.
//! Parser recognising a VRRP v2 packet.
use nom::error::Error;
use nom::multi::count;
use nom::number::complete::{be_u16, be_u32, u8};
use nom::{bits::complete::take, IResult};
use std::net::Ipv4Addr;
/// Bit-level input for nom's `bits` parsers: the remaining bytes plus the
/// number of bits already consumed from the first of them.
type BitInput<'a> = (&'a [u8], usize);
/// A VRRP version 2 packet, as decoded from the wire.
///
/// Packet format:
///
/// ```text
///  0                   1                   2                   3
///  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |Version| Type  | Virtual Rtr ID|   Priority    | Count IP Addrs|
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |   Auth Type   |   Adver Int   |          Checksum             |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                         IP Address (1)                        |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                               .                               |
/// |                               .                               |
/// |                               .                               |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                         IP Address (n)                        |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                     Authentication Data (1)                   |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                     Authentication Data (2)                   |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// ```
///
/// The version and type nibbles are validated while parsing and are not
/// stored. The authentication data words are not stored either.
#[derive(Debug, PartialEq)]
pub struct VRRPv2 {
    /// `Virtual Rtr ID` octet.
    pub virtual_router_id: u8,
    /// `Priority` octet.
    pub priority: u8,
    /// `Count IP Addrs` octet; the parser reads exactly this many entries
    /// into `ip_addrs`.
    pub count_ip_addrs: u8,
    /// Decoded `Auth Type` octet.
    pub auth_type: VRRPv2AuthType,
    /// `Adver Int` octet, taken verbatim. RFC 3768 defines it in seconds;
    /// no range check is applied here.
    pub advertisement_interval: u8,
    /// `Checksum` field as read from the wire (big-endian). Its validity is
    /// checked separately, over the whole message.
    pub checksum: u16,
    /// Virtual IPv4 addresses, in packet order.
    pub ip_addrs: Vec<Ipv4Addr>,
}
/// Reasons a buffer is rejected as a VRRPv2 packet.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VRRPv2Error {
    /// The `Auth Type` octet is not 0, 1 or 2.
    InvalidAuthType,
    /// The checksum over the VRRP message does not verify.
    InvalidChecksum,
    /// The type nibble is not 1 (advertisement).
    InvalidType,
    /// The version nibble is not 2.
    InvalidVersion,
    /// The packet could not be decoded; `from_bytes` reports this when the
    /// input ends before a required field.
    ParseError,
}

impl std::fmt::Display for VRRPv2Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::InvalidAuthType => "unsupported VRRPv2 authentication type",
            Self::InvalidChecksum => "invalid VRRPv2 checksum",
            Self::InvalidType => "unsupported VRRPv2 packet type",
            Self::InvalidVersion => "unsupported VRRP version",
            Self::ParseError => "failed to parse VRRPv2 packet",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and other
// error-aggregating types.
impl std::error::Error for VRRPv2Error {}
/// Value of the VRRPv2 `Auth Type` octet.
///
/// The discriminants are the on-the-wire values. RFC 3768 defines 0 as
/// "No Authentication" and marks 1 and 2 as reserved (RFC 2338 had used them
/// for simple text password and IP Authentication Header).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VRRPv2AuthType {
    /// Wire value 0: no authentication.
    VRRPv2AuthNoAuth = 0x00,
    /// Wire value 1: reserved.
    VRRPv2AuthReserved1 = 0x01,
    /// Wire value 2: reserved.
    VRRPv2AuthReserved2 = 0x02,
}
/// Reads the next four bits of `input` and returns them as a `u8`.
///
/// This is a named, concretely typed wrapper. It fixes the generic
/// parameters of nom's bit-level `take` (count, output and error types) once,
/// so call sites need no annotations.
fn take_nibble(input: BitInput) -> IResult<BitInput, u8> {
    const NIBBLE_BITS: usize = 4;
    take(NIBBLE_BITS)(input)
}
fn parse(input: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
let Ok(((input, _), version)) = take_nibble((input, 0)) else {
return Err(VRRPv2Error::ParseError);
};
if version != 2 {
return Err(VRRPv2Error::InvalidVersion);
}
let Ok(((input, _), type_)) = take_nibble((input, 4)) else {
return Err(VRRPv2Error::ParseError);
};
//Advertisement
if type_ != 1 {
return Err(VRRPv2Error::InvalidType);
}
let Ok((input, virtual_router_id)) = u8::<&[u8], Error<&[u8]>>(input) else {
return Err(VRRPv2Error::ParseError);
};
let Ok((input, priority)) = u8::<&[u8], Error<&[u8]>>(input) else {
return Err(VRRPv2Error::ParseError);
};
let Ok((input, count_ip_addrs)) = u8::<&[u8], Error<&[u8]>>(input) else {
return Err(VRRPv2Error::ParseError);
};
let Ok((input, auth_type)) = u8::<&[u8], Error<&[u8]>>(input) else {
return Err(VRRPv2Error::ParseError);
};
let auth_type = match auth_type {
0 => VRRPv2AuthType::VRRPv2AuthNoAuth,
1 => VRRPv2AuthType::VRRPv2AuthReserved1,
2 => VRRPv2AuthType::VRRPv2AuthReserved2,
_ => return Err(VRRPv2Error::InvalidAuthType),
};
let Ok((input, advertisement_interval)) = u8::<&[u8], Error<&[u8]>>(input) else {
return Err(VRRPv2Error::ParseError);
};
let Ok((input, checksum)) = be_u16::<&[u8], Error<&[u8]>>(input) else {
return Err(VRRPv2Error::ParseError);
};
let Ok((_, xs)) = count(be_u32::<&[u8], Error<&[u8]>>, usize::from(count_ip_addrs))(input)
else {
return Err(VRRPv2Error::ParseError);
};
let ip_addrs: Vec<Ipv4Addr> = xs.into_iter().map(Ipv4Addr::from).collect();
Ok(VRRPv2 {
virtual_router_id,
priority,
count_ip_addrs,
auth_type,
advertisement_interval,
checksum,
ip_addrs,
})
}
/// Computes the 16-bit one's complement Internet checksum (RFC 1071) of `bytes`.
///
/// Words are read big-endian. An odd trailing byte is the high-order half of
/// a final word padded with a zero octet, as RFC 1071 prescribes. Running it
/// over a message whose checksum field is correct yields `0`.
fn checksum(bytes: &[u8]) -> u16 {
    let words = bytes.chunks_exact(2);
    // RFC 1071 pads an odd-length buffer with one zero octet on the right, so
    // a leftover byte is the *high* byte of the last 16-bit word.
    let odd_tail = words.remainder().first().map_or(0, |&b| u64::from(b) << 8);
    // A u64 accumulator cannot overflow for any realistic slice. A u32
    // overflows after ~128 KiB of 0xff bytes.
    let mut sum = words.fold(odd_tail, |acc, word| {
        acc + u64::from(u16::from_be_bytes([word[0], word[1]]))
    });
    // End-around carry: fold everything above bit 15 back into the low word.
    while sum > 0xffff {
        sum = (sum & 0xffff) + (sum >> 16);
    }
    // After folding, `sum` fits in 16 bits, so the cast is lossless.
    !(sum as u16)
}
/// Parse and validate a byte array to construct a VRRPv2 struct.
///
/// `bytes` must start with the VRRP header. The packet is decoded first, so
/// structural errors take precedence over checksum errors.
///
/// The checksum is then verified over the VRRP message only: the 8-byte
/// header, the `count_ip_addrs` IPv4 addresses and the 8 bytes of
/// authentication data. Trailing bytes, such as link-layer padding, are
/// ignored. If the buffer ends before the authentication data is complete,
/// whatever was received is checksummed.
///
/// # Errors
///
/// * [`VRRPv2Error::ParseError`] if the buffer ends before the address list is complete.
/// * [`VRRPv2Error::InvalidVersion`] if the version nibble is not 2.
/// * [`VRRPv2Error::InvalidType`] if the type nibble is not 1 (advertisement).
/// * [`VRRPv2Error::InvalidAuthType`] if the `Auth Type` octet is not 0, 1 or 2.
/// * [`VRRPv2Error::InvalidChecksum`] if the message checksum does not verify.
///
/// # Examples
///
/// ```
/// use vrrpd::vrrpv2::VRRPv2;
/// use vrrpd::vrrpv2::VRRPv2AuthType;
/// use vrrpd::vrrpv2::from_bytes;
/// use std::net::Ipv4Addr;
///
/// let bytes = [
///     0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
///     0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
/// ];
/// let expected = VRRPv2 {
///     virtual_router_id: 1,
///     priority: 100,
///     count_ip_addrs: 1,
///     auth_type: VRRPv2AuthType::VRRPv2AuthNoAuth,
///     checksum: 47698,
///     advertisement_interval: 1,
///     ip_addrs: vec![Ipv4Addr::from([192, 168, 0, 1])],
/// };
/// assert_eq!(from_bytes(&bytes), Ok(expected));
/// ```
///
/// Bytes after the VRRP message do not take part in checksum validation:
///
/// ```
/// use vrrpd::vrrpv2::from_bytes;
///
/// let mut bytes = vec![
///     0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
///     0x00, 0x00, 0x00, 0x00, 0x00,
/// ];
/// bytes.extend_from_slice(&[0xde, 0xad, 0xbe, 0xef]);
/// assert!(from_bytes(&bytes).is_ok());
/// ```
pub fn from_bytes(bytes: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
    // Fixed sizes from the VRRPv2 packet layout.
    const HEADER_LEN: usize = 8;
    const IPV4_ADDR_LEN: usize = 4;
    const AUTH_DATA_LEN: usize = 8;

    let vrrpv2 = parse(bytes)?;

    let message_len =
        HEADER_LEN + IPV4_ADDR_LEN * usize::from(vrrpv2.count_ip_addrs) + AUTH_DATA_LEN;
    // The checksum covers the VRRP message only (RFC 3768 §5.3.8). Exclude any
    // trailing bytes. If the message is shorter than expected, use what was received.
    let message = bytes.get(..message_len).unwrap_or(bytes);
    if checksum(message) != 0 {
        return Err(VRRPv2Error::InvalidChecksum);
    }
    Ok(vrrpv2)
}
#[test]
fn test_incomplete_bytes() {
    // Only the version/type octet and the VRID octet are present, so decoding
    // stops at the missing priority field.
    let truncated: &[u8] = &[0x21, 0x01];
    assert_eq!(from_bytes(truncated), Err(VRRPv2Error::ParseError));
}
#[test]
fn test_invalid_version() {
    // First octet 0x31 carries version 3. The rest is a well-formed
    // advertisement with a valid checksum.
    let packet = [
        0x31, 0x2a, 0x64, 0x1, 0x0, 0x1, 0xaa, 0x29, 0xc0, 0xa8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0,
        0x0, 0x0, 0x0,
    ];
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidVersion));
}
#[test]
fn test_invalid_type() {
    // First octet 0x20: version 2 is fine, but type 0 is not ADVERTISEMENT (1).
    let frame = [
        0x20, 0x1, 0x2a, 0x0, 0x0, 0x1, 0xb5, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
        0x0, 0x0, 0x0,
    ];
    let outcome = from_bytes(&frame);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidType));
}
#[test]
fn test_invalid_auth_type() {
    // Octet 4 (`Auth Type`) is 3, outside the accepted 0..=2 range.
    let frame = [
        0x21, 0x01, 0x64, 0x01, 0x03, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    let outcome = from_bytes(&frame);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidAuthType));
}
#[test]
fn test_invalid_checksum() {
    // A valid advertisement whose checksum field has been corrupted
    // (0xba52 -> 0xbb52).
    let frame = [
        0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xbb, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    let outcome = from_bytes(&frame);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidChecksum));
}
#[test]
fn test_checksum() {
    // Worked example from RFC 1071 §3: the words sum (with end-around carry)
    // to 0xddf2, whose complement is 0x220d.
    let data = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
    let expected: u16 = 0x220d;
    assert_eq!(checksum(&data), expected);
}
#[test]
fn test_checksum_singlebyte() {
    // A lone zero byte adds nothing, so the result is the complement of 0.
    assert_eq!(checksum(&[0u8]), 0xffff);
}
#[test]
fn test_checksum_twobytes() {
    // One big-endian word 0x00ff; its complement is 0xff00.
    let word = [0x00, 0xff];
    let expected: u16 = 0xff00;
    assert_eq!(checksum(&word), expected);
}
#[test]
fn test_checksum_another() {
    // Words sum to 0x1e4ff; folding the carry gives 0xe500, complement 0x1aff.
    let data = [0xe3, 0x4f, 0x23, 0x96, 0x44, 0x27, 0x99, 0xf3];
    let expected: u16 = 0x1aff;
    assert_eq!(checksum(&data), expected);
}