From ab405fc1f8921eda8d0fff5c2b6c7e73d5f137b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tristan=20Dani=C3=ABl=20Maat?= <tm@tlater.net>
Date: Tue, 6 Sep 2022 15:14:42 +0100
Subject: [PATCH] Add proper error handling

---
 Cargo.lock    |  2 ++
 Cargo.toml    |  2 ++
 src/errors.rs | 76 +++++++++++++++++++++++++++++++++++++++++++++++++++
 src/main.rs   | 46 ++++++++++++++++++-------------
 4 files changed, 107 insertions(+), 19 deletions(-)
 create mode 100644 src/errors.rs

diff --git a/Cargo.lock b/Cargo.lock
index db16587..e01309f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1215,9 +1215,11 @@ dependencies = [
  "actix-files",
  "actix-web",
  "clap",
+ "derive_more",
  "env_logger",
  "handlebars",
  "log",
+ "serde_json",
 ]
 
 [[package]]
diff --git a/Cargo.toml b/Cargo.toml
index 30665c6..b2304a0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,6 +7,8 @@ edition = "2021"
 actix-files = "0.6.2"
 actix-web = "4.1.0"
 clap = { version = "3.2.17", features = ["derive"] }
+derive_more = "0.99.17"
 env_logger = "0.9.0"
 handlebars = { version = "4.3.3", features = ["dir_source"] }
 log = "0.4.17"
+serde_json = "1.0.83"
diff --git a/src/errors.rs b/src/errors.rs
new file mode 100644
index 0000000..2f44792
--- /dev/null
+++ b/src/errors.rs
@@ -0,0 +1,76 @@
+use actix_web::body::BoxBody;
+use actix_web::dev::ServiceResponse;
+use actix_web::http::header::ContentType;
+use actix_web::http::StatusCode;
+use actix_web::middleware::ErrorHandlerResponse;
+use actix_web::{web, HttpResponse, ResponseError};
+use derive_more::{Display, Error};
+use serde_json::json;
+
+use super::SharedData;
+
+#[derive(Debug, Display, Error)]
+pub enum UserError {
+    // #[display(fmt = "The page could not be found.")]
+    NotFound,
+    #[display(fmt = "Internal error. Try again later.")]
+    InternalError,
+}
+
+impl ResponseError for UserError {
+    fn error_response(&self) -> HttpResponse {
+        HttpResponse::build(self.status_code())
+            .insert_header(ContentType::html())
+            .body(self.to_string())
+    }
+
+    fn status_code(&self) -> StatusCode {
+        match *self {
+            UserError::NotFound => StatusCode::NOT_FOUND,
+            UserError::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
+        }
+    }
+}
+
+pub fn generic_error<B>(
+    res: ServiceResponse<B>,
+) -> actix_web::Result<ErrorHandlerResponse<BoxBody>> {
+    let data = res
+        .request()
+        .app_data::<web::Data<SharedData>>()
+        .map(|t| t.get_ref());
+
+    let status_code = res.response().status();
+    let message = if let Some(error) = status_code.canonical_reason() {
+        error
+    } else {
+        ""
+    };
+
+    let response = match data {
+        Some(SharedData {
+            handlebars,
+            config: _,
+        }) => {
+            let body = handlebars
+                .render(
+                    "error",
+                    &json!({
+                        "message": message,
+                        "status_code": status_code.as_u16()
+                    }),
+                )
+                .map_err(|_| UserError::InternalError)?;
+
+            HttpResponse::build(res.status())
+                .content_type(ContentType::html())
+                .body(body)
+        }
+        None => Err(UserError::InternalError)?,
+    };
+
+    Ok(ErrorHandlerResponse::Response(ServiceResponse::new(
+        res.into_parts().0,
+        response.map_into_left_body(),
+    )))
+}
diff --git a/src/main.rs b/src/main.rs
index 1a4fcc4..5af199f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,10 +1,17 @@
+#![allow(dead_code)]
 use std::net::SocketAddr;
+use std::path::PathBuf;
 
 use actix_files::NamedFile;
+use actix_web::http::StatusCode;
+use actix_web::middleware::ErrorHandlers;
 use actix_web::{get, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
 use clap::Parser;
 use handlebars::Handlebars;
-use std::path::PathBuf;
+
+mod errors;
+
+use errors::{generic_error, UserError};
 
 #[derive(Parser, Debug, Clone)]
 struct Config {
@@ -21,16 +28,17 @@ struct Config {
     dev_mode: bool,
 }
 
+#[derive(Debug)]
 struct SharedData<'a> {
     handlebars: Handlebars<'a>,
     config: Config,
 }
 
-#[get("/{filename:.*.html}")]
+#[get(r"/{filename:.*\.html}")]
 async fn template(
     shared: web::Data<SharedData<'_>>,
     req: HttpRequest,
-) -> Result<impl Responder, Box<dyn std::error::Error>> {
+) -> actix_web::Result<impl Responder> {
     let path = req
         .match_info()
         .query("filename")
@@ -38,22 +46,21 @@ async fn template(
         .expect("only paths with this suffix should get here");
 
     if shared.handlebars.has_template(path) {
-        let body = shared.handlebars.render(path, &())?;
-        Ok(HttpResponse::Ok().body(body))
-    } else {
         let body = shared
             .handlebars
-            .render("404", &())
-            .expect("404 template not found");
-        Ok(HttpResponse::NotFound().body(body))
+            .render(path, &())
+            .map_err(|_| UserError::InternalError)?;
+        Ok(HttpResponse::Ok().body(body))
+    } else {
+        Err(UserError::NotFound)?
     }
 }
 
-#[get("/{filename:.*}")]
+#[get("/{filename:.*[^/]+}")]
 async fn static_file(
     shared: web::Data<SharedData<'_>>,
     req: HttpRequest,
-) -> Result<impl Responder, Box<dyn std::error::Error>> {
+) -> actix_web::Result<impl Responder> {
     let requested = req.match_info().query("filename");
 
     match shared
@@ -67,17 +74,13 @@ async fn static_file(
         //
         // i.e., don't serve up /etc/passwd
         Ok(path) if path.starts_with(&shared.config.template_directory) => {
-            let file = NamedFile::open_async(path).await?;
+            let file = NamedFile::open_async(path)
+                .await
+                .map_err(|_| UserError::NotFound)?;
             Ok(file.use_last_modified(false).respond_to(&req))
         }
         // Any other cases should 404
-        _ => {
-            let body = shared
-                .handlebars
-                .render("404", &())
-                .expect("404 template not found");
-            Ok(HttpResponse::NotFound().body(body))
-        }
+        _ => Err(UserError::NotFound)?,
     }
 }
 
@@ -101,6 +104,11 @@ async fn main() -> Result<(), std::io::Error> {
 
     HttpServer::new(move || {
         App::new()
+            .wrap(
+                ErrorHandlers::new()
+                    .handler(StatusCode::NOT_FOUND, generic_error)
+                    .handler(StatusCode::INTERNAL_SERVER_ERROR, generic_error),
+            )
             .app_data(shared_data.clone())
             .service(template)
             .service(static_file)