Refactor parser.
authorSunil Nimmagadda <sunil@nimmagadda.net>
Thu, 28 Dec 2023 23:17:24 +0530
changeset 16 8c8be538d0e6
parent 15 4e56fc00d06e
child 17 90d097c9ea62
Refactor parser. Simplify error handling and provide better error Enum to report parsing errors accurately.
src/vrrpv2.rs
--- a/src/vrrpv2.rs	Sun Jul 09 02:40:54 2023 +0530
+++ b/src/vrrpv2.rs	Thu Dec 28 23:17:24 2023 +0530
@@ -1,35 +1,52 @@
-use nom::bits::{bits, streaming::take};
-use nom::combinator::map_res;
-use nom::error::{Error, ErrorKind};
+//! Parser recognising a VRRP v2 packet.
+use nom::error::Error;
 use nom::multi::count;
 use nom::number::complete::{be_u16, be_u32, u8};
-use nom::sequence::tuple;
-use nom::{Err, IResult};
+use nom::{bits::complete::take, IResult};
 use std::net::Ipv4Addr;
 
-const VRRP_REQUIRED_VERSION: u8 = 2;
-const VRRP_REQUIRED_TYPE: u8 = 1; // Advertisement
+type BitInput<'a> = (&'a [u8], usize);
+
+/// A VRRP version 2 packet.
+///
+/// Packet format
+///
+///  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+/// |Version| Type  | Virtual Rtr ID|   Priority    | Count IP Addrs|
+/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+/// |   Auth Type   |   Adver Int   |          Checksum             |
+/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+/// |                         IP Address (1)                        |
+/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+/// |                            .                                  |
+/// |                            .                                  |
+/// |                            .                                  |
+/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+/// |                         IP Address (n)                        |
+/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+/// |                     Authentication Data (1)                   |
+/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+/// |                     Authentication Data (2)                   |
+/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+#[derive(Debug, PartialEq)]
+pub struct VRRPv2 {
+    pub virtual_router_id: u8,
+    pub priority: u8,
+    pub count_ip_addrs: u8,
+    pub auth_type: VRRPv2AuthType,
+    pub advertisement_interval: u8,
+    pub checksum: u16,
+    pub ip_addrs: Vec<Ipv4Addr>,
+}
 
 #[derive(Debug, Clone, PartialEq)]
 pub enum VRRPv2Error {
-    VRRPv2ParseError,
-}
-
-type NomError<'a> = nom::Err<nom::error::Error<&'a [u8]>>;
-impl From<NomError<'_>> for VRRPv2Error {
-    fn from(_: NomError) -> Self {
-        Self::VRRPv2ParseError
-    }
-}
-
-#[derive(Debug, PartialEq)]
-pub enum VRRPVersion {
-    V2,
-}
-
-#[derive(Debug, PartialEq)]
-pub enum VRRPv2Type {
-    VRRPv2Advertisement,
+    InvalidAuthType,
+    InvalidChecksum,
+    InvalidType,
+    InvalidVersion,
+    ParseError,
 }
 
 #[derive(Debug, PartialEq)]
@@ -39,68 +56,63 @@
     VRRPv2AuthReserved2 = 0x02,
 }
 
-#[derive(Debug, PartialEq)]
-pub struct VRRPv2 {
-    pub version: VRRPVersion,
-    pub type_: VRRPv2Type,
-    pub virtual_router_id: u8,
-    pub priority: u8,
-    pub count_ip_addrs: u8,
-    pub auth_type: VRRPv2AuthType,
-    pub advertisement_interval: u8,
-    pub checksum: u16,
-    pub ip_addrs: Vec<Ipv4Addr>,
-}
-
-fn two_nibbles(input: &[u8]) -> IResult<&[u8], (u8, u8)> {
-    bits::<_, _, Error<(&[u8], usize)>, _, _>(tuple((take(4usize), take(4usize))))(input)
-}
-
-fn parse_version_type(input: &[u8]) -> IResult<&[u8], (VRRPVersion, VRRPv2Type)> {
-    let (input, pair) = two_nibbles(input)?;
-    match pair {
-        (VRRP_REQUIRED_VERSION, VRRP_REQUIRED_TYPE) => {
-            Ok((input, (VRRPVersion::V2, VRRPv2Type::VRRPv2Advertisement)))
-        }
-        _ => Err(Err::Error(Error::new(input, ErrorKind::Alt))),
-    }
+/// Helper function to let compiler infer generic parameters.
+fn take_nibble(input: BitInput) -> IResult<BitInput, u8> {
+    take(4usize)(input)
 }
 
-fn parse_auth_type(input: &[u8]) -> IResult<&[u8], VRRPv2AuthType> {
-    map_res(u8, |auth_type| {
-        Ok(match auth_type {
-            0 => VRRPv2AuthType::VRRPv2AuthNoAuth,
-            1 => VRRPv2AuthType::VRRPv2AuthReserved1,
-            2 => VRRPv2AuthType::VRRPv2AuthReserved2,
-            _ => return Err(Err::Error(Error::new(input, ErrorKind::Alt))),
-        })
-    })(input)
-}
-
-fn parse(input: &[u8]) -> IResult<&[u8], VRRPv2> {
-    let (input, (version, type_)) = parse_version_type(input)?;
-    let (input, virtual_router_id) = u8(input)?;
-    let (input, priority) = u8(input)?;
-    let (input, count_ip_addrs) = u8(input)?;
-    let (input, auth_type) = parse_auth_type(input)?;
-    let (input, advertisement_interval) = u8(input)?;
-    let (input, checksum) = be_u16(input)?;
-    let (input, xs) = count(be_u32, usize::from(count_ip_addrs))(input)?;
-    let ip_addrs = xs.into_iter().map(Ipv4Addr::from).collect();
-    Ok((
-        input,
-        VRRPv2 {
-            version,
-            type_,
-            virtual_router_id,
-            priority,
-            count_ip_addrs,
-            auth_type,
-            advertisement_interval,
-            checksum,
-            ip_addrs,
-        },
-    ))
+fn parse(input: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
+    let Ok(((input, _), version)) = take_nibble((input, 0)) else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    if version != 2 {
+        return Err(VRRPv2Error::InvalidVersion);
+    }
+    let Ok(((input, _), type_)) = take_nibble((input, 4)) else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    //Advertisement
+    if type_ != 1 {
+        return Err(VRRPv2Error::InvalidType);
+    }
+    let Ok((input, virtual_router_id)) = u8::<&[u8], Error<&[u8]>>(input) else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    let Ok((input, priority)) = u8::<&[u8], Error<&[u8]>>(input) else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    let Ok((input, count_ip_addrs)) = u8::<&[u8], Error<&[u8]>>(input) else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    let Ok((input, auth_type)) = u8::<&[u8], Error<&[u8]>>(input) else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    let auth_type = match auth_type {
+        0 => VRRPv2AuthType::VRRPv2AuthNoAuth,
+        1 => VRRPv2AuthType::VRRPv2AuthReserved1,
+        2 => VRRPv2AuthType::VRRPv2AuthReserved2,
+        _ => return Err(VRRPv2Error::InvalidAuthType),
+    };
+    let Ok((input, advertisement_interval)) = u8::<&[u8], Error<&[u8]>>(input) else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    let Ok((input, checksum)) = be_u16::<&[u8], Error<&[u8]>>(input) else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    let Ok((_, xs)) = count(be_u32::<&[u8], Error<&[u8]>>, usize::from(count_ip_addrs))(input)
+    else {
+        return Err(VRRPv2Error::ParseError);
+    };
+    let ip_addrs: Vec<Ipv4Addr> = xs.into_iter().map(Ipv4Addr::from).collect();
+    Ok(VRRPv2 {
+        virtual_router_id,
+        priority,
+        count_ip_addrs,
+        auth_type,
+        advertisement_interval,
+        checksum,
+        ip_addrs,
+    })
 }
 
 // Nightly has a nicer array_chunks API to express it more succinctly.
@@ -127,54 +139,65 @@
     checksum == 0
 }
 
+/// Parse and validate a byte array to construct a VRRPv2 struct.
+///
+/// # Examples
+///
+/// ```
+/// use vrrpd::vrrpv2::VRRPv2;
+/// use vrrpd::vrrpv2::VRRPv2AuthType;
+/// use vrrpd::vrrpv2::from_bytes;
+/// use std::net::Ipv4Addr;
+///
+/// let bytes = [
+///    0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
+///    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+/// ];
+/// let expected = VRRPv2 {
+///     virtual_router_id: 1,
+///     priority: 100,
+///     count_ip_addrs: 1,
+///     auth_type: VRRPv2AuthType::VRRPv2AuthNoAuth,
+///     checksum: 47698,
+///     advertisement_interval: 1,
+///     ip_addrs: vec![Ipv4Addr::from([192, 168, 0, 1])],
+/// };
+/// assert_eq!(from_bytes(&bytes), Ok(expected));
+/// ```
 pub fn from_bytes(bytes: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
+    let vrrpv2 = parse(bytes)?;
     if !validate_checksum(bytes) {
-        return Err(VRRPv2Error::VRRPv2ParseError);
-    }
-    match parse(bytes) {
-        Ok((_, v)) => Ok(v),
-        Err(e) => Err(e.into()),
+        return Err(VRRPv2Error::InvalidChecksum);
     }
-}
-
-#[test]
-fn test_standard_bytes() {
-    let bytes = [
-        0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
-        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-    ];
-    let got = from_bytes(&bytes).unwrap();
-    let expected = VRRPv2 {
-        version: VRRPVersion::V2,
-        type_: VRRPv2Type::VRRPv2Advertisement,
-        virtual_router_id: 1,
-        priority: 100,
-        count_ip_addrs: 1,
-        auth_type: VRRPv2AuthType::VRRPv2AuthNoAuth,
-        checksum: 47698,
-        advertisement_interval: 1,
-        ip_addrs: vec![Ipv4Addr::from([192, 168, 0, 1])],
-    };
-    assert_eq!(got, expected);
+    Ok(vrrpv2)
 }
 
 #[test]
 fn test_incomplete_bytes() {
     let bytes = [0x21, 0x01];
-    let got = from_bytes(&bytes);
-    assert_eq!(got.is_err(), true);
-    assert_eq!(got.err(), Some(VRRPv2Error::VRRPv2ParseError));
+    assert_eq!(from_bytes(&bytes), Err(VRRPv2Error::ParseError));
 }
 
 #[test]
-fn test_invalid_version_type() {
+fn test_invalid_version() {
     let bytes = [
-        0x00, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
-        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+        0x31, 0x2a, 0x64, 0x1, 0x0, 0x1, 0xaa, 0x29, 0xc0, 0xa8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0,
+        0x0, 0x0, 0x0,
     ];
     let got = from_bytes(&bytes);
     assert_eq!(got.is_err(), true);
-    assert_eq!(got.err(), Some(VRRPv2Error::VRRPv2ParseError));
+    assert_eq!(got.err(), Some(VRRPv2Error::InvalidVersion));
+}
+
+#[test]
+fn test_invalid_type() {
+    let bytes = [
+        0x20, 0x1, 0x2a, 0x0, 0x0, 0x1, 0xb5, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
+        0x0, 0x0, 0x0,
+    ];
+    let got = from_bytes(&bytes);
+    assert_eq!(got.is_err(), true);
+    assert_eq!(got.err(), Some(VRRPv2Error::InvalidType));
 }
 
 #[test]
@@ -183,9 +206,7 @@
         0x21, 0x01, 0x64, 0x01, 0x03, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
     ];
-    let got = from_bytes(&bytes);
-    assert_eq!(got.is_err(), true);
-    assert_eq!(got.err(), Some(VRRPv2Error::VRRPv2ParseError));
+    assert_eq!(from_bytes(&bytes), Err(VRRPv2Error::InvalidAuthType));
 }
 
 #[test]
@@ -194,7 +215,5 @@
         0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xbb, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
         0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
     ];
-    let got = from_bytes(&bytes);
-    assert_eq!(got.is_err(), true);
-    assert_eq!(got.err(), Some(VRRPv2Error::VRRPv2ParseError));
+    assert_eq!(from_bytes(&bytes), Err(VRRPv2Error::InvalidChecksum));
 }