Implementing ToSql and FromSql for the encrypted value type to write to a BLOB type in sqlite

master
Tom Alexander 5 years ago
parent 7276f84233
commit 542dedec8d

@ -6,8 +6,17 @@ use crypto::scrypt::{self, ScryptParams};
use crypto::sha2::Sha256;
use rand::rngs::OsRng;
use rand::Rng;
use rusqlite::types::FromSql;
use rusqlite::types::FromSqlResult;
use rusqlite::types::ToSqlOutput;
use rusqlite::types::ToSqlOutput::Owned;
use rusqlite::types::Value::Blob;
use rusqlite::types::ValueRef;
use rusqlite::ToSql;
use rustc_serialize::base64;
use rustc_serialize::base64::{FromBase64, ToBase64};
use std::convert::TryFrom;
use std::convert::TryInto;
use std::io;
pub struct EncryptedValue {
@ -16,6 +25,66 @@ pub struct EncryptedValue {
pub mac: MacResult,
}
impl ToSql for EncryptedValue {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
// Format:
// 8 bytes: length of mac
// n bytes: mac
// 8 bytes: length of iv
// n bytes: iv
// 8 bytes: length of cihpertext
// n bytes: ciphertext
let length_of_ciphertext: u64 = self.ciphertext.len().try_into().unwrap();
let length_of_iv: u64 = self.iv.len().try_into().unwrap();
let mac_bytes = self.mac.code();
let length_of_mac: u64 = mac_bytes.len().try_into().unwrap();
let full_length = (8 * 3) + length_of_mac + length_of_iv + length_of_ciphertext;
let mut out: Vec<u8> = Vec::with_capacity(full_length.try_into().unwrap());
out.extend(&length_of_mac.to_le_bytes());
out.extend(mac_bytes);
out.extend(&length_of_iv.to_le_bytes());
out.extend(&self.iv);
out.extend(&length_of_ciphertext.to_le_bytes());
out.extend(&self.ciphertext);
Ok(Owned(Blob(out)))
}
}
impl FromSql for EncryptedValue {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob().unwrap();
let length_of_mac: u64 =
u64::from_le_bytes(bytes[0..8].try_into().expect("Invalid number of bytes"));
let length_of_iv: u64 = u64::from_le_bytes(
bytes[usize::try_from(8 + length_of_mac).unwrap()
..usize::try_from(16 + length_of_mac).unwrap()]
.try_into()
.expect("Invalid number of bytes"),
);
let length_of_ciphertext: u64 = u64::from_le_bytes(
bytes[usize::try_from(16 + length_of_mac + length_of_iv).unwrap()
..usize::try_from(24 + length_of_mac + length_of_iv).unwrap()]
.try_into()
.expect("Invalid number of bytes"),
);
Ok(EncryptedValue {
ciphertext: bytes[usize::try_from(24 + length_of_mac + length_of_iv).unwrap()
..usize::try_from(24 + length_of_mac + length_of_iv + length_of_ciphertext)
.unwrap()]
.to_vec(),
iv: bytes[usize::try_from(16 + length_of_mac).unwrap()
..usize::try_from(16 + length_of_mac + length_of_iv).unwrap()]
.try_into()
.expect("Invalid number of bytes"),
mac: MacResult::new(&bytes[8..usize::try_from(8 + length_of_mac).unwrap()]),
})
}
}
pub fn get_master_key(db_conn: &db::DbHandle, master_password: &str) -> io::Result<[u8; 32]> {
let scrypt_params: ScryptParams = ScryptParams::new(12, 16, 2);
let salt: Vec<u8> = get_salt(db_conn)?;
@ -77,3 +146,43 @@ pub fn encrypt_value(value: &str, master_key: [u8; 32]) -> EncryptedValue {
mac: hmac.result(),
}
}
#[cfg(test)]
mod tests {
use crate::crypt::{encrypt_value, EncryptedValue};
use rusqlite::{Connection, NO_PARAMS};
#[test]
fn test_encrypted_value_round_trip() {
let db = Connection::open_in_memory().expect("Failed to open DB");
db.execute_batch("CREATE TABLE test (content BLOB);")
.expect("Failed to create table");
let master_key: [u8; 32] = [0u8; 32];
let encrypted_value = encrypt_value("hunter2", master_key);
db.execute(
"INSERT INTO test (content) VALUES ($1)",
&[&encrypted_value],
)
.expect("Failed to insert value into DB");
let mut stmt = db
.prepare("SELECT * FROM test")
.expect("Failed to prepare statement");
let rows: Vec<Result<EncryptedValue, rusqlite::Error>> = stmt
.query_map(NO_PARAMS, |row| {
let val: EncryptedValue = row.get(0).expect("Failed to get element from row");
Ok(val)
})
.expect("Failed to get rows")
.collect();
assert_eq!(rows.len(), 1);
for returned_result in rows {
let returned_value: EncryptedValue = returned_result.expect("Bad value returned");
assert_eq!(returned_value.ciphertext, encrypted_value.ciphertext);
assert_eq!(returned_value.iv, encrypted_value.iv);
assert_eq!(returned_value.mac.code(), encrypted_value.mac.code());
}
}
}

@ -21,6 +21,7 @@ Usage:
foil set [--db=<db>]
foil get [--db=<db>]
foil list [--db=<db>]
foil transfer [--db=<db>]
foil generate <spec>
foil (-h | --help)
@ -36,6 +37,7 @@ struct Args {
cmd_get: bool,
cmd_list: bool,
cmd_generate: bool,
cmd_transfer: bool,
flag_db: Option<String>,
arg_spec: Option<String>,
}
@ -140,6 +142,8 @@ fn set(mut db_conn: db::DbHandle, master_key: [u8; 32]) {
println!("Successfully added password");
}
fn transfer(mut db_conn: db::DbHandle, master_key: [u8; 32]) {}
fn main() -> Result<(), Box<dyn Error>> {
pretty_env_logger::init();
let args: Args = Docopt::new(USAGE)
@ -162,6 +166,8 @@ fn main() -> Result<(), Box<dyn Error>> {
get(db_conn, master_key);
} else if args.cmd_list {
list(db_conn, master_key);
} else if args.cmd_transfer {
transfer(db_conn, master_key);
}
Ok(())

Loading…
Cancel
Save