diff --git a/src/crypt.rs b/src/crypt.rs index fa9fdb7..db6f592 100644 --- a/src/crypt.rs +++ b/src/crypt.rs @@ -17,6 +17,7 @@ use rustc_serialize::base64; use rustc_serialize::base64::{FromBase64, ToBase64}; use std::convert::TryFrom; use std::convert::TryInto; +use std::error::Error; use std::io; pub struct EncryptedValue { @@ -85,6 +86,29 @@ impl FromSql for EncryptedValue { } } +impl EncryptedValue { + pub fn decrypt_to_bytes(&self, master_key: [u8; 32]) -> Vec { + let mut hmac = Hmac::new(Sha256::new(), &master_key); + hmac.input(&self.ciphertext); + if hmac.result() != self.mac { + panic!("Mac did not match, corrupted data"); + } + + let mut cipher = aes::ctr(KeySize::KeySize256, &master_key, &self.iv); + let mut output: Vec = vec![0; self.ciphertext.len()]; + cipher.process(&self.ciphertext, output.as_mut_slice()); + output + } + + pub fn decrypt_to_string( + &self, + master_key: [u8; 32], + ) -> Result { + let decrypted_bytes = self.decrypt_to_bytes(master_key); + String::from_utf8(decrypted_bytes) + } +} + pub fn get_master_key(db_conn: &db::DbHandle, master_password: &str) -> io::Result<[u8; 32]> { let scrypt_params: ScryptParams = ScryptParams::new(12, 16, 2); let salt: Vec = get_salt(db_conn)?; diff --git a/src/db.rs b/src/db.rs index 7bd6293..a854a61 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,6 +1,6 @@ use super::crypt; use crate::crypt::EncryptedValue; -use rusqlite::{Connection, NO_PARAMS}; +use rusqlite::{params, Connection, NO_PARAMS}; use rustc_serialize::base64; use rustc_serialize::base64::{FromBase64, ToBase64}; use std::error::Error; @@ -26,15 +26,15 @@ pub struct Account { pub password: String, } -#[derive(Debug)] pub struct DbNamespace { pub id: i64, - pub name: String, + pub name: EncryptedValue, } #[derive(Debug)] pub struct DbNote { pub id: i64, + pub namespace: String, pub category: String, pub title: String, pub value: String, @@ -53,6 +53,44 @@ impl DbHandle { DbHandle { conn: conn } } + pub fn get_namespace_id( + &mut self, + name: &str, + master_key: [u8; 32], + ) -> Result> { + { + let mut stmt = self + .conn + .prepare("SELECT id, name FROM namespaces") + .unwrap(); + let rows = stmt.query_map(params![], |row| { + Ok(DbNamespace { + id: row.get(0)?, + name: row.get(1)?, + }) + })?; + + for row_result in rows { + let row: DbNamespace = row_result?; + let row_name: String = row.name.decrypt_to_string(master_key)?; + if name == row_name { + return Ok(row.id); + } + } + } + + let new_namespace = crypt::encrypt_value(name, master_key); + let tx = self.conn.transaction().unwrap(); + tx.execute( + "INSERT INTO namespaces (name) VALUES ($1)", + &[&new_namespace], + ) + .unwrap(); + let rowid: i64 = tx.last_insert_rowid(); + let _ = tx.commit().unwrap(); + Ok(rowid) + } + pub fn get_db_property(&self, name: &str) -> Result, Box> { let mut stmt = self .conn