diff --git a/zcash_primitives/src/serialize.rs b/zcash_primitives/src/serialize.rs index f142943..8312312 100644 --- a/zcash_primitives/src/serialize.rs +++ b/zcash_primitives/src/serialize.rs @@ -82,6 +82,37 @@ impl Vector { } } +pub struct Optional; + +impl Optional { + pub fn read(mut reader: R, func: F) -> io::Result> + where + F: Fn(&mut R) -> io::Result, + { + match reader.read_u8()? { + 0 => Ok(None), + 1 => Ok(Some(func(&mut reader)?)), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "non-canonical Option", + )), + } + } + + pub fn write(mut writer: W, val: &Option, func: F) -> io::Result<()> + where + F: Fn(&mut W, &T) -> io::Result<()>, + { + match val { + None => writer.write_u8(0), + Some(e) => { + writer.write_u8(1)?; + func(&mut writer, e) + } + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -153,4 +184,46 @@ mod tests { eval!(vec![7; 260], expected); } } + + #[test] + fn optional() { + macro_rules! eval { + ($value:expr, $expected:expr, $write:expr, $read:expr) => { + let mut data = vec![]; + Optional::write(&mut data, &$value, $write).unwrap(); + assert_eq!(&data[..], &$expected[..]); + match Optional::read(&data[..], $read) { + Ok(v) => assert_eq!(v, $value), + Err(e) => panic!("Unexpected error: {:?}", e), + } + }; + } + + macro_rules! eval_u8 { + ($value:expr, $expected:expr) => { + eval!($value, $expected, |w, e| w.write_u8(*e), |r| r.read_u8()) + }; + } + + macro_rules! eval_vec { + ($value:expr, $expected:expr) => { + eval!( + $value, + $expected, + |w, v| Vector::write(w, v, |w, e| w.write_u8(*e)), + |r| Vector::read(r, |r| r.read_u8()) + ) + }; + } + + eval_u8!(None, [0]); + eval_u8!(Some(0), [1, 0]); + eval_u8!(Some(1), [1, 1]); + eval_u8!(Some(5), [1, 5]); + + eval_vec!(Some(vec![]), [1, 0]); + eval_vec!(Some(vec![0]), [1, 1, 0]); + eval_vec!(Some(vec![1]), [1, 1, 1]); + eval_vec!(Some(vec![5; 8]), [1, 8, 5, 5, 5, 5, 5, 5, 5, 5]); + } }