From 542dedec8deacf238640d9a9855d07caf6744c8b Mon Sep 17 00:00:00 2001 From: Tom Alexander Date: Sun, 2 Jun 2019 16:52:59 -0400 Subject: [PATCH] Implementing ToSql and FromSql for the encrypted value type to write to a BLOB type in sqlite --- src/crypt.rs | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 6 +++ 2 files changed, 115 insertions(+) diff --git a/src/crypt.rs b/src/crypt.rs index 5b9465c..0343f5f 100644 --- a/src/crypt.rs +++ b/src/crypt.rs @@ -6,8 +6,17 @@ use crypto::scrypt::{self, ScryptParams}; use crypto::sha2::Sha256; use rand::rngs::OsRng; use rand::Rng; +use rusqlite::types::FromSql; +use rusqlite::types::FromSqlResult; +use rusqlite::types::ToSqlOutput; +use rusqlite::types::ToSqlOutput::Owned; +use rusqlite::types::Value::Blob; +use rusqlite::types::ValueRef; +use rusqlite::ToSql; use rustc_serialize::base64; use rustc_serialize::base64::{FromBase64, ToBase64}; +use std::convert::TryFrom; +use std::convert::TryInto; use std::io; pub struct EncryptedValue { @@ -16,6 +25,66 @@ pub struct EncryptedValue { pub mac: MacResult, } +impl ToSql for EncryptedValue { + fn to_sql(&self) -> rusqlite::Result { + // Format: + // 8 bytes: length of mac + // n bytes: mac + // 8 bytes: length of iv + // n bytes: iv + // 8 bytes: length of cihpertext + // n bytes: ciphertext + let length_of_ciphertext: u64 = self.ciphertext.len().try_into().unwrap(); + let length_of_iv: u64 = self.iv.len().try_into().unwrap(); + let mac_bytes = self.mac.code(); + let length_of_mac: u64 = mac_bytes.len().try_into().unwrap(); + let full_length = (8 * 3) + length_of_mac + length_of_iv + length_of_ciphertext; + + let mut out: Vec = Vec::with_capacity(full_length.try_into().unwrap()); + out.extend(&length_of_mac.to_le_bytes()); + out.extend(mac_bytes); + out.extend(&length_of_iv.to_le_bytes()); + out.extend(&self.iv); + out.extend(&length_of_ciphertext.to_le_bytes()); + out.extend(&self.ciphertext); + + Ok(Owned(Blob(out))) + } +} + +impl FromSql for EncryptedValue { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob().unwrap(); + + let length_of_mac: u64 = + u64::from_le_bytes(bytes[0..8].try_into().expect("Invalid number of bytes")); + let length_of_iv: u64 = u64::from_le_bytes( + bytes[usize::try_from(8 + length_of_mac).unwrap() + ..usize::try_from(16 + length_of_mac).unwrap()] + .try_into() + .expect("Invalid number of bytes"), + ); + let length_of_ciphertext: u64 = u64::from_le_bytes( + bytes[usize::try_from(16 + length_of_mac + length_of_iv).unwrap() + ..usize::try_from(24 + length_of_mac + length_of_iv).unwrap()] + .try_into() + .expect("Invalid number of bytes"), + ); + + Ok(EncryptedValue { + ciphertext: bytes[usize::try_from(24 + length_of_mac + length_of_iv).unwrap() + ..usize::try_from(24 + length_of_mac + length_of_iv + length_of_ciphertext) + .unwrap()] + .to_vec(), + iv: bytes[usize::try_from(16 + length_of_mac).unwrap() + ..usize::try_from(16 + length_of_mac + length_of_iv).unwrap()] + .try_into() + .expect("Invalid number of bytes"), + mac: MacResult::new(&bytes[8..usize::try_from(8 + length_of_mac).unwrap()]), + }) + } +} + pub fn get_master_key(db_conn: &db::DbHandle, master_password: &str) -> io::Result<[u8; 32]> { let scrypt_params: ScryptParams = ScryptParams::new(12, 16, 2); let salt: Vec = get_salt(db_conn)?; @@ -77,3 +146,43 @@ pub fn encrypt_value(value: &str, master_key: [u8; 32]) -> EncryptedValue { mac: hmac.result(), } } + +#[cfg(test)] +mod tests { + use crate::crypt::{encrypt_value, EncryptedValue}; + use rusqlite::{Connection, NO_PARAMS}; + + #[test] + fn test_encrypted_value_round_trip() { + let db = Connection::open_in_memory().expect("Failed to open DB"); + db.execute_batch("CREATE TABLE test (content BLOB);") + .expect("Failed to create table"); + let master_key: [u8; 32] = [0u8; 32]; + let encrypted_value = encrypt_value("hunter2", master_key); + db.execute( + "INSERT INTO test (content) VALUES ($1)", + &[&encrypted_value], + ) + .expect("Failed to insert value into DB"); + + let mut stmt = db + .prepare("SELECT * FROM test") + .expect("Failed to prepare statement"); + let rows: Vec> = stmt + .query_map(NO_PARAMS, |row| { + let val: EncryptedValue = row.get(0).expect("Failed to get element from row"); + Ok(val) + }) + .expect("Failed to get rows") + .collect(); + + assert_eq!(rows.len(), 1); + + for returned_result in rows { + let returned_value: EncryptedValue = returned_result.expect("Bad value returned"); + assert_eq!(returned_value.ciphertext, encrypted_value.ciphertext); + assert_eq!(returned_value.iv, encrypted_value.iv); + assert_eq!(returned_value.mac.code(), encrypted_value.mac.code()); + } + } +} diff --git a/src/main.rs b/src/main.rs index 8c2fb31..09a73e0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,7 @@ Usage: foil set [--db=] foil get [--db=] foil list [--db=] + foil transfer [--db=] foil generate foil (-h | --help) @@ -36,6 +37,7 @@ struct Args { cmd_get: bool, cmd_list: bool, cmd_generate: bool, + cmd_transfer: bool, flag_db: Option, arg_spec: Option, } @@ -140,6 +142,8 @@ fn set(mut db_conn: db::DbHandle, master_key: [u8; 32]) { println!("Successfully added password"); } +fn transfer(mut db_conn: db::DbHandle, master_key: [u8; 32]) {} + fn main() -> Result<(), Box> { pretty_env_logger::init(); let args: Args = Docopt::new(USAGE) @@ -162,6 +166,8 @@ fn main() -> Result<(), Box> { get(db_conn, master_key); } else if args.cmd_list { list(db_conn, master_key); + } else if args.cmd_transfer { + transfer(db_conn, master_key); } Ok(())