diff --git a/src/main.rs b/src/main.rs
index 79ab501..85a5ef0 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,6 +12,7 @@ use tower_http::set_header::SetResponseHeaderLayer;
 mod error;
 mod owner_tree;
 mod parse;
+mod rtrim_iterator;
 mod sexp;
 
 #[tokio::main]
diff --git a/src/owner_tree.rs b/src/owner_tree.rs
index 5beb9ea..048a511 100644
--- a/src/owner_tree.rs
+++ b/src/owner_tree.rs
@@ -1,6 +1,9 @@
 use serde::Serialize;
 
-use crate::sexp::{sexp_with_padding, Token};
+use crate::{
+    rtrim_iterator::RTrimIterator,
+    sexp::{sexp_with_padding, Token},
+};
 
 pub fn build_owner_tree<'a>(
     body: &'a str,
@@ -207,6 +210,9 @@ fn get_line_numbers<'s>(
     begin: u32,
     end: u32,
 ) -> Result<(u32, u32), Box<dyn std::error::Error>> {
+    // The returned range is half-open: `end_line` is exclusive. It is used to highlight the
+    // lines containing text relevant to the token, so even when the token stops partway
+    // through a line, `end_line` is the number of the line after it.
     let start_line = original_source
         .chars()
         .into_iter()
@@ -214,12 +220,15 @@ fn get_line_numbers<'s>(
         .filter(|x| *x == '\n')
         .count()
         + 1;
-    let end_line = original_source
-        .chars()
-        .into_iter()
-        .take(usize::try_from(end)? - 1)
-        .filter(|x| *x == '\n')
-        .count()
-        + 1;
+    let end_line = {
+        let content_up_to_and_including_token = original_source
+            .chars()
+            .take(usize::try_from(end)? - 1);
+        // Drop a single trailing newline (if any): one extra line is added below either way,
+        // so a token ending in a newline must not be counted twice.
+        let without_trailing_newline = RTrimIterator::new(content_up_to_and_including_token, '\n');
+        without_trailing_newline.filter(|x| *x == '\n').count() + 2
+    };
+
     Ok((u32::try_from(start_line)?, u32::try_from(end_line)?))
 }
diff --git a/src/rtrim_iterator.rs b/src/rtrim_iterator.rs
new file mode 100644
index 0000000..3f5e346
--- /dev/null
+++ b/src/rtrim_iterator.rs
@@ -0,0 +1,110 @@
+/// Iterator adaptor that removes at most one trailing `needle` from a stream of characters.
+///
+/// Only the very last character is ever dropped, so `"abcd\n\n"` yields `"abcd\n"`, and a needle
+/// that is followed by any other character is passed through unchanged.
+pub struct RTrimIterator<I> {
+    /// The wrapped character stream.
+    iter: I,
+    /// The character to drop when it is the final item of `iter`.
+    needle: char,
+    /// A character held back until the next item (or the end of the stream) has been seen.
+    buffer: Option<char>,
+}
+
+impl<I> Iterator for RTrimIterator<I>
+where
+    I: Iterator<Item = char>,
+{
+    type Item = char;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        loop {
+            match (self.buffer, self.iter.next()) {
+                (None, None) => {
+                    // End of the stream with nothing held back: the input did not end with
+                    // the needle, so there is nothing to drop.
+                    return None;
+                }
+                (None, Some(chr)) if chr == self.needle => {
+                    // Found a needle. Hold it back and loop again, because we do not yet
+                    // know whether it is the last character.
+                    self.buffer = Some(chr);
+                }
+                (None, Some(chr)) => {
+                    // Nothing held back and this is not the needle: emit it immediately.
+                    return Some(chr);
+                }
+                (Some(buf), None) if buf == self.needle => {
+                    // End of the stream while holding the needle: this is the trailing
+                    // character we were asked to drop, so it stays in the buffer forever.
+                    return None;
+                }
+                (Some(_), None) => {
+                    // End of the stream and the held-back character is not the needle, so emit it.
+                    return self.buffer.take();
+                }
+                (Some(_), Some(chr)) => {
+                    // Another character follows, so whatever is held back cannot be the
+                    // last one: emit it and hold back the new character instead.
+                    return self.buffer.replace(chr);
+                }
+            };
+        }
+    }
+}
+
+impl<I> RTrimIterator<I> {
+    /// Wraps `iter` so that a single trailing `needle`, if present, is not yielded.
+    pub fn new(iter: I, needle: char) -> RTrimIterator<I> {
+        RTrimIterator {
+            iter,
+            needle,
+            buffer: None,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn no_match() {
+        let input = "abcd";
+        let output: String = RTrimIterator::new(input.chars(), '\n').collect();
+        assert_eq!(output, input);
+    }
+
+    #[test]
+    fn middle_match() {
+        let input = "ab\ncd";
+        let output: String = RTrimIterator::new(input.chars(), '\n').collect();
+        assert_eq!(output, input);
+    }
+
+    #[test]
+    fn end_match() {
+        let input = "abcd\n";
+        let output: String = RTrimIterator::new(input.chars(), '\n').collect();
+        assert_eq!(output, "abcd");
+    }
+
+    #[test]
+    fn double_match() {
+        let input = "abcd\n\n";
+        let output: String = RTrimIterator::new(input.chars(), '\n').collect();
+        assert_eq!(output, "abcd\n");
+    }
+
+    #[test]
+    fn empty_input() {
+        let output: String = RTrimIterator::new("".chars(), '\n').collect();
+        assert_eq!(output, "");
+    }
+
+    #[test]
+    fn only_needle() {
+        let output: String = RTrimIterator::new("\n".chars(), '\n').collect();
+        assert_eq!(output, "");
+    }
+}