Rectify checksum calculation.
authorSunil Nimmagadda <sunil@nimmagadda.net>
Tue, 02 Jan 2024 14:52:44 +0530
changeset 17 90d097c9ea62
parent 16 8c8be538d0e6
child 18 2470a15711b1
Rectify checksum calculation. Add some tests found in RFC1071 as examples.
src/vrrpv2.rs
--- a/src/vrrpv2.rs	Thu Dec 28 23:17:24 2023 +0530
+++ b/src/vrrpv2.rs	Tue Jan 02 14:52:44 2024 +0530
@@ -115,58 +115,28 @@
     })
 }
 
-// Nightly has a nicer array_chunks API to express it more succinctly.
-// let mut chunks = bytes.array_chunks(2);
-// let mut sum = chunks.map(u16::from_ne_bytes).map(|b| b as u32).sum::<u32>();
-// // handle the remainder
-// if let Some([b]) = chunks.remainder() {
-//     sum += *b as u32
-// }
-
-// Shadowing can be used to avoid `mut`...
-// let sum =...;
-// let sum = (sum & 0xffff) + (sum >> 16);
-// let sum = (sum & 0xffff) + (sum >> 16);
-// manually un-rolling while loop since it's needed atmost twice for an u32.
-fn validate_checksum(bytes: &[u8]) -> bool {
-    let mut sum: u32 = bytes.chunks(2).fold(0, |acc: u32, x| {
-        acc + u32::from(u16::from_ne_bytes(x.try_into().unwrap()))
-    });
+// nightly has as_chunks that allows for a nicer code...
+// let (chunks, remainder) = bytes.as_chunks(2);
+// fold chunks and remainder without an if.
+fn checksum(bytes: &[u8]) -> u16 {
+    let mut sum: u32 = 0;
+    for chunk in bytes.chunks(2) {
+        // Left over byte if any
+        if chunk.len() == 1 {
+            sum += u32::from(chunk[0]);
+        } else {
+            sum += u32::from(u16::from_be_bytes(chunk.try_into().unwrap()));
+        }
+    }
     while (sum >> 16) > 0 {
         sum = (sum & 0xffff) + (sum >> 16);
     }
-    let checksum = !(sum as u16);
-    checksum == 0
+    !(sum as u16)
 }
 
-/// Parse and validate a byte array to construct a VRRPv2 struct.
-///
-/// # Examples
-///
-/// ```
-/// use vrrpd::vrrpv2::VRRPv2;
-/// use vrrpd::vrrpv2::VRRPv2AuthType;
-/// use vrrpd::vrrpv2::from_bytes;
-/// use std::net::Ipv4Addr;
-///
-/// let bytes = [
-///    0x21, 0x01, 0x64, 0x01, 0x00, 0x01, 0xba, 0x52, 0xc0, 0xa8, 0x00, 0x01, 0x00, 0x00, 0x00,
-///    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-/// ];
-/// let expected = VRRPv2 {
-///     virtual_router_id: 1,
-///     priority: 100,
-///     count_ip_addrs: 1,
-///     auth_type: VRRPv2AuthType::VRRPv2AuthNoAuth,
-///     checksum: 47698,
-///     advertisement_interval: 1,
-///     ip_addrs: vec![Ipv4Addr::from([192, 168, 0, 1])],
-/// };
-/// assert_eq!(from_bytes(&bytes), Ok(expected));
-/// ```
 pub fn from_bytes(bytes: &[u8]) -> Result<VRRPv2, VRRPv2Error> {
     let vrrpv2 = parse(bytes)?;
-    if !validate_checksum(bytes) {
+    if checksum(bytes) != 0 {
         return Err(VRRPv2Error::InvalidChecksum);
     }
     Ok(vrrpv2)
@@ -217,3 +187,27 @@
     ];
     assert_eq!(from_bytes(&bytes), Err(VRRPv2Error::InvalidChecksum));
 }
+
+#[test]
+fn test_checksum() {
+    let bytes = [0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7];
+    assert_eq!(checksum(&bytes), 0x220d);
+}
+
+#[test]
+fn test_checksum_singlebyte() {
+    let bytes = [0; 1];
+    assert_eq!(checksum(&bytes), 0xffff);
+}
+
+#[test]
+fn test_checksum_twobytes() {
+    let bytes = [0x00, 0xff];
+    assert_eq!(checksum(&bytes), 0xff00);
+}
+
+#[test]
+fn test_checksum_another() {
+    let bytes = [0xe3, 0x4f, 0x23, 0x96, 0x44, 0x27, 0x99, 0xf3];
+    assert_eq!(checksum(&bytes), 0x1aff);
+}