From 41cc65e7d3a426bce5b250fa693e648e4e287f33 Mon Sep 17 00:00:00 2001 From: Tom Alexander Date: Sun, 14 Jul 2024 19:01:10 -0400 Subject: [PATCH] Add graceful shutdown. --- Cargo.lock | 1 + Cargo.toml | 4 ++-- src/main.rs | 38 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7235c67..a971788 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -729,6 +729,7 @@ dependencies = [ "http-body", "http-body-util", "pin-project-lite", + "tokio", "tower-layer", "tower-service", "tracing", diff --git a/Cargo.toml b/Cargo.toml index aa6c998..037acdd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,8 +23,8 @@ axum = { version = "0.7.5", default-features = false, features = ["tokio", "http serde = { version = "1.0.204", features = ["derive"] } # default std serde_json = { version = "1.0.120", default-features = false, features = ["std"] } -tokio = { version = "1.38.0", default-features = false, features = ["macros", "process", "rt", "rt-multi-thread"] } -tower-http = { version = "0.5.2", default-features = false, features = ["trace"] } +tokio = { version = "1.38.0", default-features = false, features = ["macros", "process", "rt", "rt-multi-thread", "signal"] } +tower-http = { version = "0.5.2", default-features = false, features = ["trace", "timeout"] } # default attributes, std, tracing-attributes tracing = { version = "0.1.40", default-features = false, features = ["attributes", "std", "tracing-attributes", "async-await"] } # default alloc, ansi, fmt, nu-ansi-term, registry, sharded-slab, smallvec, std, thread_local, tracing-log diff --git a/src/main.rs b/src/main.rs index 96575a7..f035abb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,10 +1,14 @@ #![forbid(unsafe_code)] +use std::time::Duration; + use axum::http::StatusCode; use axum::routing::get; use axum::routing::post; use axum::Json; use axum::Router; use serde::Serialize; +use tokio::signal; +use tower_http::timeout::TimeoutLayer; use tower_http::trace::TraceLayer; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; @@ -27,14 +31,44 @@ async fn main() -> Result<(), Box> { let app = Router::new() .route("/health", get(health)) .route("/hook", post(hook)) - .layer(TraceLayer::new_for_http()); + .layer(( + TraceLayer::new_for_http(), + // Add a timeout layer so graceful shutdown can't wait forever. + TimeoutLayer::new(Duration::from_secs(600)), + )); let listener = tokio::net::TcpListener::bind("0.0.0.0:8080").await?; tracing::info!("listening on {}", listener.local_addr().unwrap()); - axum::serve(listener, app).await?; + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await?; Ok(()) } +async fn shutdown_signal() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } +} + async fn health() -> (StatusCode, Json) { (StatusCode::OK, Json(HealthResponse { ok: true })) }