diff --git a/librustzcash/src/rustzcash.rs b/librustzcash/src/rustzcash.rs index 0697272..7e7c1a0 100644 --- a/librustzcash/src/rustzcash.rs +++ b/librustzcash/src/rustzcash.rs @@ -143,9 +143,13 @@ pub extern "system" fn librustzcash_init_zksnark_params( let output_path = Path::new(OsStr::from_bytes(unsafe { slice::from_raw_parts(output_path, output_path_len) })); - let sprout_path = Path::new(OsStr::from_bytes(unsafe { - slice::from_raw_parts(sprout_path, sprout_path_len) - })); + let sprout_path = if sprout_path.is_null() { + None + } else { + Some(Path::new(OsStr::from_bytes(unsafe { + slice::from_raw_parts(sprout_path, sprout_path_len) + }))) + }; init_zksnark_params( spend_path, @@ -174,8 +178,13 @@ pub extern "system" fn librustzcash_init_zksnark_params( OsString::from_wide(unsafe { slice::from_raw_parts(spend_path, spend_path_len) }); let output_path = OsString::from_wide(unsafe { slice::from_raw_parts(output_path, output_path_len) }); - let sprout_path = - OsString::from_wide(unsafe { slice::from_raw_parts(sprout_path, sprout_path_len) }); + let sprout_path = if sprout_path.is_null() { + None + } else { + Some(OsStr::from_wide(unsafe { + slice::from_raw_parts(sprout_path, sprout_path_len) + })) + }; init_zksnark_params( Path::new(&spend_path), @@ -192,7 +201,7 @@ fn init_zksnark_params( spend_hash: *const c_char, output_path: &Path, output_hash: *const c_char, - sprout_path: &Path, + sprout_path: Option<&Path>, sprout_hash: *const c_char, ) { // Initialize jubjub parameters here @@ -206,9 +215,15 @@ fn init_zksnark_params( .to_str() .expect("hash should be a valid string"); - let sprout_hash = unsafe { CStr::from_ptr(sprout_hash) } - .to_str() - .expect("hash should be a valid string"); + let sprout_hash = if sprout_path.is_none() { + None + } else { + Some( + unsafe { CStr::from_ptr(sprout_hash) } + .to_str() + .expect("hash should be a valid string"), + ) + }; // Load params let (spend_params, spend_vk, output_params, output_vk, sprout_vk) = load_parameters( @@ -216,8 +231,8 @@ fn init_zksnark_params( spend_hash, output_path, output_hash, - Some(sprout_path), - Some(sprout_hash), + sprout_path, + sprout_hash, ); // Caller is responsible for calling this function once, so @@ -225,11 +240,11 @@ fn init_zksnark_params( unsafe { SAPLING_SPEND_PARAMS = Some(spend_params); SAPLING_OUTPUT_PARAMS = Some(output_params); - SPROUT_GROTH16_PARAMS_PATH = Some(sprout_path.to_owned()); + SPROUT_GROTH16_PARAMS_PATH = sprout_path.map(|p| p.to_owned()); SAPLING_SPEND_VK = Some(spend_vk); SAPLING_OUTPUT_VK = Some(output_vk); - SPROUT_GROTH16_VK = Some(sprout_vk.unwrap()); + SPROUT_GROTH16_VK = sprout_vk; } }