diff --git a/src/endpoints.rs b/src/endpoints.rs new file mode 100644 index 0000000..6effbfc --- /dev/null +++ b/src/endpoints.rs @@ -0,0 +1,48 @@ +use std::sync::{Arc, RwLock}; + +use rocket::{get, http::RawStr, response::{status::NotFound, Redirect}, serde::{self, json::Json}, State}; +use serde::Serialize; + +use crate::{database::Database, get_id, settings::Settings}; + +/// An endpoint to obtain information about the server's capabilities +#[get("/info")] +pub fn server_info(settings: &State) -> Json { + Json(ServerInfo { + max_filesize: settings.max_filesize, + max_duration: settings.duration.maximum.num_seconds() as u32, + default_duration: settings.duration.default.num_seconds() as u32, + allowed_durations: settings.duration.allowed.clone().into_iter().map(|t| t.num_seconds() as u32).collect(), + }) +} + +#[derive(Serialize, Debug)] +#[serde(crate = "rocket::serde")] +pub struct ServerInfo { + max_filesize: u64, + max_duration: u32, + default_duration: u32, + #[serde(skip_serializing_if = "Vec::is_empty")] + allowed_durations: Vec, +} + +/// Look up the hash of a file to find it. This only returns the first +/// hit for a hash, so different filenames may not be found. +#[get("/f/")] +pub fn lookup( + db: &State>>, + id: &str +) -> Result> { + for file in db.read().unwrap().files.values() { + if file.hash().to_hex()[0..10].to_string() == id { + let filename = get_id( + file.name(), + *file.hash() + ); + let filename = RawStr::new(&filename).percent_encode().to_string(); + return Ok(Redirect::to(format!("/files/{}", filename))) + } + } + + Err(NotFound(())) +} diff --git a/src/main.rs b/src/main.rs index 5106dd4..b8ec548 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,21 @@ mod database; mod strings; mod settings; +mod endpoints; +mod utils; + +use std::{fs, sync::{Arc, RwLock}}; -use std::{fs, path::Path, sync::{Arc, RwLock}}; -use blake3::Hash; use chrono::{DateTime, TimeDelta, Utc}; use database::{clean_loop, Database, MochiFile}; +use endpoints::{lookup, server_info}; use log::info; use rocket::{ - data::{Limits, ToByteUnit}, form::Form, fs::{FileServer, Options, TempFile}, get, http::{ContentType, RawStr}, post, response::{content::{RawCss, RawJavaScript}, status::NotFound, Redirect}, routes, serde::{json::Json, Serialize}, tokio, Config, FromForm, State + data::{Limits, ToByteUnit}, form::Form, fs::{FileServer, Options, TempFile}, get, http::ContentType, post, response::content::{RawCss, RawJavaScript}, routes, serde::{json::Json, Serialize}, tokio, Config, FromForm, State }; use settings::Settings; use strings::{parse_time_string, to_pretty_time}; +use utils::{get_id, hash_file}; use uuid::Uuid; use maud::{html, Markup, DOCTYPE, PreEscaped}; @@ -216,61 +220,6 @@ struct ClientResponse { pub expires: Option>, } -/// Get a filename based on the file's hashed name -fn get_id(name: &str, hash: Hash) -> String { - hash.to_hex()[0..10].to_string() + "_" + name -} - -/// Get the Blake3 hash of a file, without reading it all into memory, and also get the size -async fn hash_file>(input: &P) -> Result { - let mut hasher = blake3::Hasher::new(); - hasher.update_mmap_rayon(input)?; - - Ok(hasher.finalize()) -} - -/// An endpoint to obtain information about the server's capabilities -#[get("/info")] -fn server_info(settings: &State) -> Json { - Json(ServerInfo { - max_filesize: settings.max_filesize, - max_duration: settings.duration.maximum.num_seconds() as u32, - default_duration: settings.duration.default.num_seconds() as u32, - allowed_durations: settings.duration.allowed.clone().into_iter().map(|t| t.num_seconds() as u32).collect(), - }) -} - -#[derive(Serialize, Debug)] -#[serde(crate = "rocket::serde")] -struct ServerInfo { - max_filesize: u64, - max_duration: u32, - default_duration: u32, - #[serde(skip_serializing_if = "Vec::is_empty")] - allowed_durations: Vec, -} - -/// Look up the hash of a file to find it. This only returns the first -/// hit for a hash, so different filenames may not be found. -#[get("/f/")] -fn lookup( - db: &State>>, - id: &str -) -> Result> { - for file in db.read().unwrap().files.values() { - if file.hash().to_hex()[0..10].to_string() == id { - let filename = get_id( - file.name(), - *file.hash() - ); - let filename = RawStr::new(&filename).percent_encode().to_string(); - return Ok(Redirect::to(format!("/files/{}", filename))) - } - } - - Err(NotFound(())) -} - #[rocket::main] async fn main() { // Get or create config file diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..7a0faae --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,15 @@ +use std::path::Path; +use blake3::Hash; + +/// Get a filename based on the file's hashed name +pub fn get_id(name: &str, hash: Hash) -> String { + hash.to_hex()[0..10].to_string() + "_" + name +} + +/// Get the Blake3 hash of a file, without reading it all into memory, and also get the size +pub async fn hash_file>(input: &P) -> Result { + let mut hasher = blake3::Hasher::new(); + hasher.update_mmap_rayon(input)?; + + Ok(hasher.finalize()) +}