diff --git a/Cargo.lock b/Cargo.lock index f9f0b15..fb04ecc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -339,6 +339,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -567,6 +568,15 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.9" @@ -1890,10 +1900,14 @@ name = "webhook_bridge" version = "0.0.1" dependencies = [ "axum", + "base64 0.22.1", + "hmac", + "http-body-util", "k8s-openapi", "kube", "serde", "serde_json", + "sha2", "tokio", "tower-http", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 3e3e3dc..84aa584 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,12 +20,16 @@ include = [ [dependencies] # default form, http1, json, matched-path, original-uri, query, tokio, tower-log, tracing axum = { version = "0.7.5", default-features = false, features = ["tokio", "http1", "http2", "json"] } +base64 = "0.22.1" +hmac = "0.12.1" +http-body-util = "0.1.2" k8s-openapi = { version = "0.22.0", default-features = false, features = ["v1_30"] } # default client, config, rustls-tls kube = { version = "0.92.1", default-features = false, features = ["client", "config", "rustls-tls", "derive", "runtime"] } serde = { version = "1.0.204", features = ["derive"] } # default std serde_json = { version = "1.0.120", default-features = false, features = ["std"] } +sha2 = "0.10.8" tokio = { version = "1.38.0", default-features = false, features = ["macros", "process", "rt-multi-thread", "signal"] } tower-http = { version = "0.5.2", default-features = false, features = ["trace", "timeout"] } # default attributes, std, tracing-attributes diff --git a/src/main.rs b/src/main.rs index 7153142..2d6502f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use std::time::Duration; use axum::http::StatusCode; +use axum::middleware; use axum::routing::get; use axum::routing::post; use axum::Json; @@ -15,6 +16,7 @@ use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; use self::webhook::hook; +use self::webhook::verify_signature; mod hook_push; mod webhook; @@ -35,15 +37,16 @@ async fn main() -> Result<(), Box> { .expect("Set KUBECONFIG to a valid kubernetes config."); let app = Router::new() - .route("/health", get(health)) .route("/hook", post(hook)) + .layer(middleware::from_fn(verify_signature)) + .route("/health", get(health)) .layer(( TraceLayer::new_for_http(), // Add a timeout layer so graceful shutdown can't wait forever. TimeoutLayer::new(Duration::from_secs(600)), )); - let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await?; + let listener = tokio::net::TcpListener::bind("0.0.0.0:9988").await?; tracing::info!("listening on {}", listener.local_addr().unwrap()); axum::serve(listener, app) .with_graceful_shutdown(shutdown_signal()) diff --git a/src/webhook.rs b/src/webhook.rs index f044ca9..86f46ef 100644 --- a/src/webhook.rs +++ b/src/webhook.rs @@ -1,17 +1,29 @@ +use std::future::Future; + use axum::async_trait; +use axum::body::Body; +use axum::body::Bytes; use axum::extract::FromRequest; use axum::extract::Request; use axum::http::HeaderMap; use axum::http::StatusCode; +use axum::middleware::Next; use axum::response::IntoResponse; use axum::response::Response; use axum::Json; use axum::RequestExt; +use base64::{engine::general_purpose, Engine as _}; +use hmac::Hmac; +use hmac::Mac; +use http_body_util::BodyExt; use serde::Serialize; +use sha2::Sha256; use tracing::debug; use crate::hook_push::HookPush; +type HmacSha256 = Hmac; + pub(crate) async fn hook( _headers: HeaderMap, payload: HookRequest, @@ -72,3 +84,68 @@ pub(crate) struct HookResponse { ok: bool, message: Option, } + +pub(crate) async fn verify_signature( + request: Request, + next: Next, +) -> Result { + let signature = request + .headers() + .get("X-Gitea-Signature") + .ok_or(StatusCode::BAD_REQUEST.into_response())?; + let signature = signature + .to_str() + .map_err(|_| StatusCode::BAD_REQUEST.into_response())?; + let signature = hex_to_bytes(signature).ok_or(StatusCode::BAD_REQUEST.into_response())?; + let secret = std::env::var("WEBHOOK_BRIDGE_HMAC_SECRET") + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())?; + + let request = + inspect_request_body(request, move |body| check_hash(body, secret, signature)).await?; + + Ok(next.run(request).await) +} + +async fn inspect_request_body(request: Request, inspector: F) -> Result +where + F: FnOnce(Bytes) -> Fut, + Fut: Future>, +{ + let (parts, body) = request.into_parts(); + + let bytes = body + .collect() + .await + .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response())? + .to_bytes(); + + let bytes = inspector(bytes).await?; + + Ok(Request::from_parts(parts, Body::from(bytes))) +} + +async fn check_hash(body: Bytes, secret: String, signature: Vec) -> Result { + tracing::info!("Checking signature {:02x?}", signature.as_slice()); + tracing::info!("Using secret {:?}", secret); + tracing::info!("and body {}", general_purpose::STANDARD.encode(&body)); + let mut mac = HmacSha256::new_from_slice(secret.as_bytes()) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response())?; + mac.update(&body); + mac.verify_slice(&signature) + .map_err(|e| (StatusCode::UNAUTHORIZED, e.to_string()).into_response())?; + Ok(body) +} + +fn hex_to_bytes(s: &str) -> Option> { + if s.len() % 2 == 0 { + (0..s.len()) + .step_by(2) + .map(|i| { + s.get(i..i + 2) + .and_then(|sub| u8::from_str_radix(sub, 16).ok()) + }) + .collect() + } else { + None + } +} diff --git a/test_webhook.bash b/test_webhook.bash index 9c89f06..4ce1167 100755 --- a/test_webhook.bash +++ b/test_webhook.bash @@ -5,220 +5,34 @@ IFS=$'\n\t' DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" function main() { - local payload - payload=$(cat <