// src/vrrpv2.rs
// author Sunil Nimmagadda <sunil@nimmagadda.net>
// Sun, 26 May 2024 13:19:51 +0530
// changeset 29 277a2f8b3653
// parent 26 4ad31d279c35
// permissions -rw-r--r--
// Make VRRPv2Error std::error::Error trait compliant.

use std::fmt::Display;
use std::io::Read;
use std::io::{self, Cursor};
use std::net::Ipv4Addr;

/// A VRRP version 2 packet.
///
/// Only the fields listed below are kept after decoding. The version and type
/// nibbles of the first byte are validated by the parser but not stored, and
/// the trailing authentication data words are not decoded into this struct.
///
/// # Packet format
///
/// ```text
///  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |Version| Type  | Virtual Rtr ID|   Priority    | Count IP Addrs|
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |   Auth Type   |   Adver Int   |          Checksum             |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                         IP Address (1)                        |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                            .                                  |
/// |                            .                                  |
/// |                            .                                  |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                         IP Address (n)                        |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                     Authentication Data (1)                   |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// |                     Authentication Data (2)                   |
/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
/// ```
#[derive(Debug, PartialEq)]
pub struct VRRPv2 {
    /// "Virtual Rtr ID" field: the second byte of the packet.
    pub virtual_router_id: u8,
    /// "Priority" field: the third byte of the packet.
    pub priority: u8,
    /// "Count IP Addrs" field: how many IPv4 addresses the parser reads into
    /// `ip_addrs`.
    pub count_ip_addrs: u8,
    /// "Auth Type" field, decoded from its one-byte wire value.
    pub auth_type: VRRPv2AuthType,
    /// "Adver Int" field as carried in the packet.
    // NOTE(review): presumably expressed in seconds, per the VRRPv2
    // specification -- confirm before relying on the unit.
    pub advertisement_interval: u8,
    /// "Checksum" field exactly as carried in the packet (read big-endian).
    pub checksum: u16,
    /// The "IP Address" entries, in the order they appear in the packet.
    pub ip_addrs: Vec<Ipv4Addr>,
}

/// Reasons a byte buffer can be rejected while decoding a VRRPv2 packet.
///
/// The type implements [`std::error::Error`], so it can be boxed as
/// `Box<dyn std::error::Error>` and propagated with `?` from functions that
/// return a boxed error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VRRPv2Error {
    /// The authentication type byte is not one of the recognised values.
    InvalidAuthType,
    /// The checksum computed over the buffer did not verify.
    InvalidChecksum,
    /// The type nibble (low four bits of the first byte) is not 1.
    InvalidType,
    /// The version nibble (high four bits of the first byte) is not 2.
    InvalidVersion,
    /// The buffer ended before every expected field could be read.
    ParseError,
}

impl Display for VRRPv2Error {
    /// Write the short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::InvalidAuthType => "Invalid Auth Type",
            Self::InvalidChecksum => "Invalid Checksum",
            Self::InvalidType => "Invalid Type",
            Self::InvalidVersion => "Invalid Version",
            Self::ParseError => "Parse Error",
        };
        f.write_str(description)
    }
}

// `Debug` and `Display` are the only requirements; no variant wraps an
// underlying cause, so the default `source()` (returning `None`) is correct.
impl std::error::Error for VRRPv2Error {}

/// Authentication type carried in the "Auth Type" byte of a VRRPv2 packet.
///
/// Each discriminant equals the wire value the parser maps to that variant;
/// any other byte value is rejected with `VRRPv2Error::InvalidAuthType`.
#[derive(Debug, PartialEq)]
pub enum VRRPv2AuthType {
    /// Wire value 0: no authentication.
    VRRPv2AuthNoAuth = 0x00,
    /// Wire value 1.
    // NOTE(review): named "reserved" here, presumably following the VRRPv2
    // specification's treatment of the retired authentication types -- confirm.
    VRRPv2AuthReserved1 = 0x01,
    /// Wire value 2 (see the note on `VRRPv2AuthReserved1`).
    VRRPv2AuthReserved2 = 0x02,
}

/// Fixed-width, big-endian integer reads on top of a byte source.
///
/// Every method consumes exactly as many bytes as the integer is wide and
/// returns the underlying I/O error when fewer bytes are available.
trait BytesReader {
    /// Read a single byte.
    fn read_u8(&mut self) -> io::Result<u8>;
    /// Read two bytes and interpret them as a big-endian `u16`.
    fn read_u16(&mut self) -> io::Result<u16>;
    /// Read four bytes and interpret them as a big-endian `u32`.
    fn read_u32(&mut self) -> io::Result<u32>;
}

/// Fill an `N`-byte array from `src`, failing if `src` cannot supply `N` bytes.
fn read_array<R: Read, const N: usize>(src: &mut R) -> io::Result<[u8; N]> {
    let mut raw = [0u8; N];
    src.read_exact(&mut raw)?;
    Ok(raw)
}

impl<T: AsRef<[u8]>> BytesReader for Cursor<T> {
    fn read_u8(&mut self) -> io::Result<u8> {
        read_array::<_, 1>(self).map(u8::from_be_bytes)
    }

    fn read_u16(&mut self) -> io::Result<u16> {
        read_array::<_, 2>(self).map(u16::from_be_bytes)
    }

    fn read_u32(&mut self) -> io::Result<u32> {
        read_array::<_, 4>(self).map(u32::from_be_bytes)
    }
}

fn parse(bytes: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
    let mut rdr = Cursor::new(bytes);
    let Ok(vertype) = rdr.read_u8() else {
        return Err(VRRPv2Error::ParseError);
    };
    if (vertype & 0xF) != 1 {
        return Err(VRRPv2Error::InvalidType);
    }
    if (vertype >> 4) != 2 {
        return Err(VRRPv2Error::InvalidVersion);
    }
    let Ok(virtual_router_id) = rdr.read_u8() else {
        return Err(VRRPv2Error::ParseError);
    };
    let Ok(priority) = rdr.read_u8() else {
        return Err(VRRPv2Error::ParseError);
    };
    let Ok(count_ip_addrs) = rdr.read_u8() else {
        return Err(VRRPv2Error::ParseError);
    };
    let Ok(auth_type) = rdr.read_u8() else {
        return Err(VRRPv2Error::ParseError);
    };
    let auth_type = match auth_type {
        0 => VRRPv2AuthType::VRRPv2AuthNoAuth,
        1 => VRRPv2AuthType::VRRPv2AuthReserved1,
        2 => VRRPv2AuthType::VRRPv2AuthReserved2,
        _ => return Err(VRRPv2Error::InvalidAuthType),
    };
    let Ok(advertisement_interval) = rdr.read_u8() else {
        return Err(VRRPv2Error::ParseError);
    };
    let Ok(checksum) = rdr.read_u16() else {
        return Err(VRRPv2Error::ParseError);
    };
    let mut ip_addrs = Vec::with_capacity(count_ip_addrs as usize);
    for _i in 0..count_ip_addrs {
        let Ok(b) = rdr.read_u32() else {
            return Err(VRRPv2Error::ParseError);
        };
        ip_addrs.push(Ipv4Addr::from(b));
    }
    Ok(VRRPv2 {
        virtual_router_id,
        priority,
        count_ip_addrs,
        auth_type,
        advertisement_interval,
        checksum,
        ip_addrs,
    })
}

/// Decode a VRRPv2 packet from raw bytes and verify its checksum.
///
/// The buffer is first decoded field by field; only if that succeeds is the
/// checksum computed over the whole of `bytes` and required to be zero.
///
/// # Errors
///
/// Returns whichever error the field decoding produced, or
/// `VRRPv2Error::InvalidChecksum` when decoding succeeded but the checksum
/// over `bytes` is non-zero.
///
/// # Examples
///
/// ```
/// use vrrpd::vrrpv2::VRRPv2;
/// use vrrpd::vrrpv2::VRRPv2AuthType;
/// use vrrpd::vrrpv2::from_bytes;
/// use std::net::Ipv4Addr;
///
/// let bytes = [
///    0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
///    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
/// ];
/// let expected = VRRPv2 {
///     virtual_router_id: 1,
///     priority: 100,
///     count_ip_addrs: 1,
///     auth_type: VRRPv2AuthType::VRRPv2AuthNoAuth,
///     checksum: 47698,
///     advertisement_interval: 1,
///     ip_addrs: vec![Ipv4Addr::from([192, 168, 0, 1])],
/// };
/// assert_eq!(from_bytes(&bytes), Ok(expected));
/// ```
pub fn from_bytes(bytes: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
    // Field decoding runs first, so its errors take precedence over a
    // checksum mismatch.
    let packet = parse(bytes)?;
    match checksum(bytes) {
        0 => Ok(packet),
        _ => Err(VRRPv2Error::InvalidChecksum),
    }
}

/// Compute the 16-bit ones'-complement Internet checksum of `bytes`.
///
/// The buffer is summed as consecutive big-endian 16-bit words, carries are
/// folded back into the low 16 bits, and the bitwise complement of the result
/// is returned. Running this over a buffer that already contains its own
/// checksum yields 0 when the checksum is valid.
///
/// When `bytes` has odd length, the final byte is treated as if it were
/// followed by a zero byte, i.e. it becomes the high-order half of the last
/// word. An empty buffer sums to zero and therefore yields `0xffff`.
fn checksum(bytes: &[u8]) -> u16 {
    let words = bytes.chunks_exact(2);

    // Zero-pad an odd trailing byte on the right so it lines up with the
    // big-endian interpretation used for every complete word.
    let padded_tail = words
        .remainder()
        .first()
        .map_or(0, |&last| u64::from(last) << 8);

    // A 64-bit accumulator cannot overflow for any buffer of realistic size,
    // whereas a 32-bit one overflows once the input exceeds about 128 KiB.
    let mut sum = words
        .map(|pair| u64::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum::<u64>()
        + padded_tail;

    // End-around carry: fold everything above bit 15 back into the low word.
    while sum >> 16 != 0 {
        sum = (sum & 0xffff) + (sum >> 16);
    }

    // The folding loop leaves `sum <= 0xffff`, so this cast loses nothing.
    !(sum as u16)
}

/// A buffer that stops after two bytes cannot supply the remaining header
/// fields, so decoding must report a parse error.
#[test]
fn test_incomplete_bytes() {
    let truncated: [u8; 2] = [0x21, 0x01];
    let outcome = from_bytes(&truncated);
    assert_eq!(outcome, Err(VRRPv2Error::ParseError));
}

/// First byte 0x31 carries version nibble 3 with a valid type nibble of 1,
/// so the packet must be rejected for its version.
#[test]
fn test_invalid_version() {
    let packet = [
        0x31, 0x1, 0x2a, 0x0, 0x0, 0x1, 0xb5, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
        0x0, 0x0, 0x0,
    ];
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidVersion));
}

/// First byte 0x20 carries type nibble 0, which is not the accepted value 1,
/// so the packet must be rejected for its type.
#[test]
fn test_invalid_type() {
    let packet = [
        0x20, 0x2a, 0x64, 0x1, 0x0, 0x1, 0xaa, 0x29, 0xc0, 0xa8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0,
        0x0, 0x0, 0x0,
    ];
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidType));
}

/// The auth type byte (fifth byte) is 0x03, which is outside the accepted
/// values 0, 1 and 2.
#[test]
fn test_invalid_auth_type() {
    let packet = [
        0x21, 0x01, 0x64, 0x01, 0x03, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidAuthType));
}

/// A structurally valid packet whose checksum field (0xbb52) does not match
/// its contents must be rejected after decoding succeeds.
#[test]
fn test_invalid_checksum() {
    let packet = [
        0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xbb, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    ];
    let outcome = from_bytes(&packet);
    assert_eq!(outcome, Err(VRRPv2Error::InvalidChecksum));
}

/// Checksum of an eight-byte buffer whose word sum needs carry folding.
#[test]
fn test_checksum() {
    let input: [u8; 8] = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
    let computed = checksum(&input);
    assert_eq!(computed, 0x220d);
}

/// A lone zero byte sums to zero, so the complemented checksum is all ones.
#[test]
fn test_checksum_singlebyte() {
    let input: [u8; 1] = [0x00];
    let computed = checksum(&input);
    assert_eq!(computed, 0xffff);
}

/// A single big-endian word 0x00ff complements to 0xff00.
#[test]
fn test_checksum_twobytes() {
    let input: [u8; 2] = [0x00, 0xff];
    let computed = checksum(&input);
    assert_eq!(computed, 0xff00);
}

/// A second eight-byte vector exercising the carry-folding path.
#[test]
fn test_checksum_another() {
    let input: [u8; 8] = [0xe3, 0x4f, 0x23, 0x96, 0x44, 0x27, 0x99, 0xf3];
    let computed = checksum(&input);
    assert_eq!(computed, 0x1aff);
}