diff --git a/lib/src/lightclient.rs b/lib/src/lightclient.rs index 1121e7f..f3f2bbb 100644 --- a/lib/src/lightclient.rs +++ b/lib/src/lightclient.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, RwLock, Mutex, mpsc::channel}; use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering}; use std::path::{Path, PathBuf}; use std::fs::File; -use std::collections::HashMap; +use std::collections::{HashSet, HashMap}; use std::cmp::{max, min}; use std::io; use std::io::prelude::*; @@ -821,22 +821,35 @@ impl LightClient { let anchor_height: i32 = self.wallet.read().unwrap().get_anchor_height() as i32; { - // Collect Sapling notes let wallet = self.wallet.read().unwrap(); + + // First, collect all extfvk's that are spendable (i.e., we have the private key) + let spendable_address: HashSet = wallet.get_all_zaddresses().iter() + .filter(|address| wallet.have_spending_key_for_zaddress(address)) + .map(|address| address.clone()) + .collect(); + + // Collect Sapling notes wallet.txs.read().unwrap().iter() .flat_map( |(txid, wtx)| { + let spendable_address = spendable_address.clone(); wtx.notes.iter().filter_map(move |nd| if !all_notes && nd.spent.is_some() { None } else { + let address = LightWallet::note_address(self.config.hrp_sapling_address(), nd); + let spendable = address.is_some() && + spendable_address.contains(&address.clone().unwrap()) && + wtx.block <= anchor_height && nd.spent.is_none() && nd.unconfirmed_spent.is_none(); + Some(object!{ "created_in_block" => wtx.block, "datetime" => wtx.datetime, "created_in_txid" => format!("{}", txid), "value" => nd.note.value, "is_change" => nd.is_change, - "address" => LightWallet::note_address(self.config.hrp_sapling_address(), nd), - "spendable" => wtx.block <= anchor_height && nd.spent.is_none() && nd.unconfirmed_spent.is_none(), + "address" => address, + "spendable" => spendable, "spent" => nd.spent.map(|spent_txid| format!("{}", spent_txid)), "spent_at_height" => nd.spent_at_height.map(|h| format!("{}", h)), "unconfirmed_spent" => nd.unconfirmed_spent.map(|spent_txid| format!("{}", spent_txid)), diff --git a/lib/src/lightwallet.rs b/lib/src/lightwallet.rs index 9212074..ba5762a 100644 --- a/lib/src/lightwallet.rs +++ b/lib/src/lightwallet.rs @@ -1024,10 +1024,7 @@ impl LightWallet { .filter(|nd| nd.spent.is_none() && nd.unconfirmed_spent.is_none()) .filter(|nd| { // Check to see if we have this note's spending key. - match self.zkeys.read().unwrap().iter().find(|zk| zk.extfvk == nd.extfvk) { - Some(zk) => zk.keytype == WalletZKeyType::HdKey || zk.keytype == WalletZKeyType::ImportedSpendingKey, - _ => false - } + self.have_spendingkey_for_extfvk(&nd.extfvk) }) .filter(|nd| { // TODO, this whole section is shared with verified_balance. Refactor it. match addr.clone() { @@ -1048,6 +1045,22 @@ impl LightWallet { .sum::() } + pub fn have_spendingkey_for_extfvk(&self, extfvk: &ExtendedFullViewingKey) -> bool { + match self.zkeys.read().unwrap().iter().find(|zk| zk.extfvk == *extfvk) { + None => false, + Some(zk) => zk.have_spending_key() + } + } + + pub fn have_spending_key_for_zaddress(&self, address: &String) -> bool { + match self.zkeys.read().unwrap().iter() + .find(|zk| encode_payment_address(self.config.hrp_sapling_address(), &zk.zaddress) == *address) + { + None => false, + Some(zk) => zk.have_spending_key() + } + } + fn add_toutput_to_wtx(&self, height: i32, timestamp: u64, txid: &TxId, vout: &TxOut, n: u64) { let mut txs = self.txs.write().unwrap(); diff --git a/lib/src/lightwallet/tests.rs b/lib/src/lightwallet/tests.rs index 0f05f8c..ef16d4e 100644 --- a/lib/src/lightwallet/tests.rs +++ b/lib/src/lightwallet/tests.rs @@ -813,6 +813,8 @@ fn test_z_spend_to_z() { let sent_tx = Transaction::read(&raw_tx[..]).unwrap(); let sent_txid = sent_tx.txid(); + assert_eq!(wallet.have_spending_key_for_zaddress(wallet.get_all_zaddresses().get(0).unwrap()), true); + // Now, the note should be unconfirmed spent { let txs = wallet.txs.read().unwrap(); @@ -2363,6 +2365,8 @@ fn test_import_vk() { assert_eq!(wallet.zkeys.read().unwrap()[1].keytype, WalletZKeyType::ImportedViewKey); assert_eq!(wallet.zkeys.read().unwrap()[1].hdkey_num, None); + assert_eq!(wallet.have_spending_key_for_zaddress(&zaddr.to_string()), false); + // Importing it again should fail assert!(wallet.add_imported_sk(viewkey.to_string(), 0).starts_with("Error")); assert_eq!(wallet.get_all_zaddresses().len(), 2); @@ -2411,6 +2415,8 @@ fn test_import_sk_upgrade_vk() { assert_eq!(wallet.zkeys.read().unwrap()[1].hdkey_num, None); assert!(wallet.zkeys.read().unwrap()[1].extsk.is_none()); + assert_eq!(wallet.have_spending_key_for_zaddress(&zaddr.to_string()), false); + // Importing it again should fail because it already exists assert!(wallet.add_imported_sk(viewkey.to_string(), 0).starts_with("Error")); assert_eq!(wallet.get_all_zaddresses().len(), 2); @@ -2431,6 +2437,8 @@ fn test_import_sk_upgrade_vk() { assert_eq!(wallet.zkeys.read().unwrap()[1].keytype, WalletZKeyType::ImportedSpendingKey); assert_eq!(wallet.zkeys.read().unwrap()[1].hdkey_num, None); assert!(wallet.zkeys.read().unwrap()[1].extsk.is_some()); + + assert_eq!(wallet.have_spending_key_for_zaddress(&zaddr.to_string()), true); } #[test] @@ -2500,9 +2508,13 @@ fn test_encrypted_zreceive() { let (_, raw_tx) = wallet.send_to_address(branch_id, &ss, &so, vec![(&ext_address, AMOUNT_SENT, Some(outgoing_memo.clone()))], |_| Ok(' '.to_string())).unwrap(); + assert_eq!(wallet.have_spending_key_for_zaddress(wallet.get_all_zaddresses().get(0).unwrap()), true); + // Now that we have the transaction, we'll encrypt the wallet wallet.encrypt(password.clone()).unwrap(); + assert_eq!(wallet.have_spending_key_for_zaddress(wallet.get_all_zaddresses().get(0).unwrap()), true); + // Scan the tx and make sure it gets added let sent_tx = Transaction::read(&raw_tx[..]).unwrap(); let sent_txid = sent_tx.txid(); diff --git a/lib/src/lightwallet/walletzkey.rs b/lib/src/lightwallet/walletzkey.rs index d6c3a1f..3c97c49 100644 --- a/lib/src/lightwallet/walletzkey.rs +++ b/lib/src/lightwallet/walletzkey.rs @@ -102,6 +102,10 @@ impl WalletZKey { } } + pub fn have_spending_key(&self) -> bool { + self.extsk.is_some() || self.enc_key.is_some() || self.hdkey_num.is_some() + } + fn serialized_version() -> u8 { return 1; }