You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

214 lines
7.4 KiB
Rust

use super::db;
use crypto::aes::{self, KeySize};
use crypto::hmac::Hmac;
use crypto::mac::{Mac, MacResult};
use crypto::scrypt::{self, ScryptParams};
use crypto::sha2::Sha256;
use rand::rngs::OsRng;
use rand::Rng;
use rusqlite::types::FromSql;
use rusqlite::types::FromSqlResult;
use rusqlite::types::ToSqlOutput;
use rusqlite::types::ToSqlOutput::Owned;
use rusqlite::types::Value::Blob;
use rusqlite::types::ValueRef;
use rusqlite::ToSql;
use rustc_serialize::base64;
use rustc_serialize::base64::{FromBase64, ToBase64};
use std::convert::TryFrom;
use std::convert::TryInto;
use std::io;
/// An encrypted value bundled with the data needed to authenticate and decrypt it.
///
/// Instances are produced by `encrypt_value` and can be stored in / loaded from
/// a SQLite BLOB column through the `ToSql` / `FromSql` implementations.
pub struct EncryptedValue {
/// The encrypted bytes; `encrypt_value` produces exactly as many bytes as the plaintext had.
pub ciphertext: Vec<u8>,
/// Initialisation vector passed to the AES-CTR cipher; `encrypt_value` draws it from the OS RNG.
pub iv: [u8; 32],
/// HMAC-SHA256 of `ciphertext`, keyed with the master key. The IV is not covered by the MAC.
pub mac: MacResult,
}
impl ToSql for EncryptedValue {
    /// Serialises the value into one owned SQLite BLOB.
    ///
    /// The blob is three length-prefixed fields, written back to back. Every
    /// length prefix is a little-endian `u64`:
    ///
    /// ```text
    /// 8 bytes: length of mac
    /// n bytes: mac
    /// 8 bytes: length of iv
    /// n bytes: iv
    /// 8 bytes: length of ciphertext
    /// n bytes: ciphertext
    /// ```
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
        // Field order defines the on-disk format; do not reorder.
        let fields: [&[u8]; 3] = [self.mac.code(), &self.iv[..], self.ciphertext.as_slice()];

        // Each field contributes its 8-byte prefix plus its payload.
        let total_len: usize = fields.iter().map(|field| 8 + field.len()).sum();
        let mut blob: Vec<u8> = Vec::with_capacity(total_len);

        for field in fields.iter() {
            let field_len: u64 = field.len().try_into().unwrap();
            blob.extend_from_slice(&field_len.to_le_bytes());
            blob.extend_from_slice(field);
        }

        Ok(Owned(Blob(blob)))
    }
}
impl FromSql for EncryptedValue {
fn column_result(value: ValueRef) -> FromSqlResult<Self> {
let bytes = value.as_blob().unwrap();
let length_of_mac: u64 =
u64::from_le_bytes(bytes[0..8].try_into().expect("Invalid number of bytes"));
let length_of_iv: u64 = u64::from_le_bytes(
bytes[usize::try_from(8 + length_of_mac).unwrap()
..usize::try_from(16 + length_of_mac).unwrap()]
.try_into()
.expect("Invalid number of bytes"),
);
let length_of_ciphertext: u64 = u64::from_le_bytes(
bytes[usize::try_from(16 + length_of_mac + length_of_iv).unwrap()
..usize::try_from(24 + length_of_mac + length_of_iv).unwrap()]
.try_into()
.expect("Invalid number of bytes"),
);
Ok(EncryptedValue {
ciphertext: bytes[usize::try_from(24 + length_of_mac + length_of_iv).unwrap()
..usize::try_from(24 + length_of_mac + length_of_iv + length_of_ciphertext)
.unwrap()]
.to_vec(),
iv: bytes[usize::try_from(16 + length_of_mac).unwrap()
..usize::try_from(16 + length_of_mac + length_of_iv).unwrap()]
.try_into()
.expect("Invalid number of bytes"),
mac: MacResult::new(&bytes[8..usize::try_from(8 + length_of_mac).unwrap()]),
})
}
}
impl EncryptedValue {
    /// Authenticates the ciphertext and returns the decrypted bytes.
    ///
    /// An HMAC-SHA256 over `ciphertext`, keyed with `master_key`, is compared
    /// against the stored `mac` before any decryption happens. The ciphertext
    /// is then run through AES-256-CTR using the same key and the stored `iv`.
    ///
    /// # Panics
    ///
    /// Panics if the computed MAC does not equal the stored one.
    pub fn decrypt_to_bytes(&self, master_key: [u8; 32]) -> Vec<u8> {
        let mut authenticator = Hmac::new(Sha256::new(), &master_key);
        authenticator.input(&self.ciphertext);
        let computed_mac = authenticator.result();
        if computed_mac != self.mac {
            panic!("Mac did not match, corrupted data");
        }

        // CTR is a stream mode: the output has the same length as the input.
        let mut plaintext: Vec<u8> = vec![0; self.ciphertext.len()];
        let mut stream = aes::ctr(KeySize::KeySize256, &master_key, &self.iv);
        stream.process(&self.ciphertext, &mut plaintext);
        plaintext
    }

    /// Decrypts like [`decrypt_to_bytes`](Self::decrypt_to_bytes) and interprets
    /// the result as UTF-8.
    ///
    /// Returns `Err` if the decrypted bytes are not valid UTF-8; a MAC mismatch
    /// still panics inside `decrypt_to_bytes`.
    pub fn decrypt_to_string(
        &self,
        master_key: [u8; 32],
    ) -> Result<String, std::string::FromUtf8Error> {
        String::from_utf8(self.decrypt_to_bytes(master_key))
    }
}
pub fn get_master_key(db_conn: &db::DbHandle, master_password: &str) -> io::Result<[u8; 32]> {
let scrypt_params: ScryptParams = ScryptParams::new(12, 16, 2);
let salt: Vec<u8> = get_salt(db_conn)?;
// 256 bit derived key
let mut derived_key = [0u8; 32];
scrypt::scrypt(
master_password.as_bytes(),
&*salt,
&scrypt_params,
&mut derived_key,
);
Ok(derived_key)
}
/// Returns the key-derivation salt for this database, creating it on first use.
///
/// The salt is kept in the `"salt"` db property as standard base64. When the
/// property is absent, a fresh 128-bit salt is drawn from the OS RNG, written
/// to the property and returned.
///
/// # Errors
///
/// * `io::ErrorKind::Other` if the property cannot be read from the database.
/// * `io::ErrorKind::InvalidData` if the stored salt is not valid base64.
/// * Whatever error `OsRng::new()` reports if the OS RNG is unavailable.
fn get_salt(db_conn: &db::DbHandle) -> io::Result<Vec<u8>> {
    let existing_salt: Option<String> = db_conn.get_db_property("salt").map_err(|e| {
        io::Error::new(
            io::ErrorKind::Other,
            format!("There was a problem reading from the db: {:?}", e),
        )
    })?;

    match existing_salt {
        // A corrupt stored salt is reported to the caller rather than aborting
        // the whole process.
        Some(encoded) => encoded
            .from_base64()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)),
        None => {
            let mut rng = OsRng::new()?;
            // 128 bit salt
            let salt: Vec<u8> = rng.gen::<[u8; 16]>().to_vec();
            // NOTE(review): the outcome of `set_db_property` is not checked here.
            // If it can fail, the salt would not be persisted and later runs
            // would derive a different key - confirm its return type and handle it.
            db_conn.set_db_property("salt", &salt.to_base64(base64::STANDARD));
            Ok(salt)
        }
    }
}
/// Authenticates and decrypts a raw ciphertext given its IV and MAC.
///
/// Computes HMAC-SHA256 over `value` keyed with `master_key`, compares it to
/// `mac`, and then decrypts `value` with AES-256-CTR using `master_key` and `iv`.
///
/// # Panics
///
/// Panics if the computed MAC does not equal `mac`.
pub fn decrypt_value(value: Vec<u8>, master_key: [u8; 32], iv: [u8; 32], mac: [u8; 32]) -> Vec<u8> {
    let mut authenticator = Hmac::new(Sha256::new(), &master_key);
    authenticator.input(value.as_slice());
    let expected_mac = MacResult::new(&mac);
    if authenticator.result() != expected_mac {
        panic!("Mac did not match, corrupted data");
    }

    // Stream cipher: plaintext length equals ciphertext length.
    let mut plaintext: Vec<u8> = vec![0; value.len()];
    let mut stream = aes::ctr(KeySize::KeySize256, &master_key, &iv);
    stream.process(value.as_slice(), &mut plaintext);
    plaintext
}
/// Encrypts `value` with AES-256-CTR and authenticates the result.
///
/// A fresh random 32-byte IV is drawn from the OS RNG for every call. The MAC
/// is an HMAC-SHA256 over the ciphertext only, keyed with `master_key`.
///
/// # Panics
///
/// Panics if the OS random number generator cannot be opened.
pub fn encrypt_value(value: &str, master_key: [u8; 32]) -> EncryptedValue {
    let mut rng = OsRng::new().unwrap();
    let iv: [u8; 32] = rng.gen();

    let plaintext = value.as_bytes();
    let mut ciphertext: Vec<u8> = vec![0; plaintext.len()];
    let mut stream = aes::ctr(KeySize::KeySize256, &master_key, &iv);
    stream.process(plaintext, &mut ciphertext);

    // The MAC is taken over the ciphertext, not the plaintext.
    let mut authenticator = Hmac::new(Sha256::new(), &master_key);
    authenticator.input(ciphertext.as_slice());
    let mac = authenticator.result();

    EncryptedValue {
        ciphertext,
        iv,
        mac,
    }
}
#[cfg(test)]
mod tests {
    use crate::crypt::{encrypt_value, EncryptedValue};
    use rusqlite::{Connection, NO_PARAMS};

    /// Stores an `EncryptedValue` in an in-memory SQLite BLOB column and checks
    /// that reading it back yields the same ciphertext, IV and MAC.
    #[test]
    fn test_encrypted_value_round_trip() {
        let conn = Connection::open_in_memory().expect("Failed to open DB");
        conn.execute_batch("CREATE TABLE test (content BLOB);")
            .expect("Failed to create table");

        // An all-zero key is sufficient: only serialisation is under test.
        let key: [u8; 32] = [0u8; 32];
        let original = encrypt_value("hunter2", key);
        conn.execute("INSERT INTO test (content) VALUES ($1)", &[&original])
            .expect("Failed to insert value into DB");

        let mut select = conn
            .prepare("SELECT * FROM test")
            .expect("Failed to prepare statement");
        let fetched: Vec<Result<EncryptedValue, rusqlite::Error>> = select
            .query_map(NO_PARAMS, |row| {
                let decoded: EncryptedValue =
                    row.get(0).expect("Failed to get element from row");
                Ok(decoded)
            })
            .expect("Failed to get rows")
            .collect();

        assert_eq!(fetched.len(), 1);
        for entry in fetched {
            let round_tripped: EncryptedValue = entry.expect("Bad value returned");
            assert_eq!(round_tripped.ciphertext, original.ciphertext);
            assert_eq!(round_tripped.iv, original.iv);
            assert_eq!(round_tripped.mac.code(), original.mac.code());
        }
    }
}