diff --git a/src/types/util.rs b/src/types/util.rs index 7eab912..6480ed7 100644 --- a/src/types/util.rs +++ b/src/types/util.rs @@ -259,15 +259,33 @@ fn impl_coalesce_whitespace_escaped<'s, C: Fn(char) -> bool>( escape_character: char, escapable_characters: C, ) -> Cow<'s, str> { - let mut state = CoalesceWhitespaceEscaped::Normal; + let mut state = CoalesceWhitespaceEscaped::Normal { + in_whitespace: false, + }; for (offset, c) in input.char_indices() { state = match (state, c) { - (CoalesceWhitespaceEscaped::Normal, c) if c == escape_character => { + (CoalesceWhitespaceEscaped::Normal { in_whitespace: _ }, c) + if c == escape_character => + { CoalesceWhitespaceEscaped::NormalEscaping { escape_offset: offset, } } - (CoalesceWhitespaceEscaped::Normal, ' ' | '\t' | '\r' | '\n') => { + (CoalesceWhitespaceEscaped::Normal { in_whitespace }, ' ') => { + if in_whitespace { + let mut ret = String::with_capacity(input.len()); + ret.push_str(&input[..offset]); + CoalesceWhitespaceEscaped::RequiresMutation { + in_whitespace: true, + ret, + } + } else { + CoalesceWhitespaceEscaped::Normal { + in_whitespace: true, + } + } + } + (CoalesceWhitespaceEscaped::Normal { in_whitespace: _ }, '\t' | '\r' | '\n') => { let mut ret = String::with_capacity(input.len()); ret.push_str(&input[..offset]); ret.push(' '); @@ -276,7 +294,11 @@ fn impl_coalesce_whitespace_escaped<'s, C: Fn(char) -> bool>( ret, } } - (CoalesceWhitespaceEscaped::Normal, _) => CoalesceWhitespaceEscaped::Normal, + (CoalesceWhitespaceEscaped::Normal { in_whitespace: _ }, _) => { + CoalesceWhitespaceEscaped::Normal { + in_whitespace: false, + } + } (CoalesceWhitespaceEscaped::NormalEscaping { escape_offset }, c) if escapable_characters(c) => { @@ -290,9 +312,15 @@ fn impl_coalesce_whitespace_escaped<'s, C: Fn(char) -> bool>( } } + (CoalesceWhitespaceEscaped::NormalEscaping { escape_offset: _ }, ' ') => { + // We didn't escape the character so continue as normal. + CoalesceWhitespaceEscaped::Normal { + in_whitespace: true, + } + } ( CoalesceWhitespaceEscaped::NormalEscaping { escape_offset: _ }, - ' ' | '\t' | '\r' | '\n', + '\t' | '\r' | '\n', ) => { // We didn't escape the character but we hit whitespace anyway. let mut ret = String::with_capacity(input.len()); @@ -305,7 +333,9 @@ fn impl_coalesce_whitespace_escaped<'s, C: Fn(char) -> bool>( } (CoalesceWhitespaceEscaped::NormalEscaping { escape_offset: _ }, _) => { // We didn't escape the character so continue as normal. - CoalesceWhitespaceEscaped::Normal + CoalesceWhitespaceEscaped::Normal { + in_whitespace: false, + } } ( @@ -379,7 +409,6 @@ fn impl_coalesce_whitespace_escaped<'s, C: Fn(char) -> bool>( ) => { ret.push(matched_escape_character); ret.push(c); - // TODO CoalesceWhitespaceEscaped::RequiresMutation { in_whitespace: false, ret, @@ -388,7 +417,7 @@ fn impl_coalesce_whitespace_escaped<'s, C: Fn(char) -> bool>( } } match state { - CoalesceWhitespaceEscaped::Normal => Cow::Borrowed(input), + CoalesceWhitespaceEscaped::Normal { in_whitespace: _ } => Cow::Borrowed(input), CoalesceWhitespaceEscaped::NormalEscaping { escape_offset: _ } => Cow::Borrowed(input), CoalesceWhitespaceEscaped::RequiresMutation { in_whitespace: _, @@ -405,7 +434,9 @@ fn impl_coalesce_whitespace_escaped<'s, C: Fn(char) -> bool>( } enum CoalesceWhitespaceEscaped { - Normal, + Normal { + in_whitespace: bool, + }, NormalEscaping { escape_offset: usize, }, @@ -437,8 +468,7 @@ mod tests { let input = "foo bar baz"; let output = coalesce_whitespace_escaped('&', |c| "".contains(c))(input); assert_eq!(output, "foo bar baz"); - // TODO: Technically this should be a Borrowed but to keep the code simple for now we are treating all whitespace as causing ownership. - assert!(matches!(output, Cow::Owned(_))); + assert!(matches!(output, Cow::Borrowed(_))); Ok(()) } @@ -468,4 +498,14 @@ mod tests { assert!(matches!(output, Cow::Owned(_))); Ok(()) } + + #[test] + fn coalesce_whitespace_escaped_escape_mismatch_around_whitespace( + ) -> Result<(), Box> { + let input = "foo& bar &baz"; + let output = coalesce_whitespace_escaped('&', |c| "z".contains(c))(input); + assert_eq!(output, "foo& bar &baz"); + assert!(matches!(output, Cow::Borrowed(_))); + Ok(()) + } }